"""Granite Vision image and chart Q&A.
Supports single-turn and multi-turn conversations using ibm-granite/granite-vision-4.1-4b.

ZeroGPU streaming pattern:
- answer_question_stream is decorated with @spaces.GPU so the entire generator
  runs inside the ZeroGPU subprocess with GPU access.
- model.generate() is called in a plain background thread (no @spaces.GPU needed —
  the subprocess already has GPU access process-wide).
- The streamer is never passed across process boundaries, avoiding pickling errors.
- model_loader loads the model on CPU then moves it to CUDA to avoid caching_allocator_warmup
  triggering torch._C._cuda_init() before ZeroGPU can intercept it.
"""

import threading
from collections.abc import Generator
from typing import Any

import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

from model_loader import load_model, load_processor


@spaces.GPU(duration=120)
def answer_question_stream(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> Generator[str, None, None]:
    """Stream an answer token-by-token about an image using Granite Vision.

    Runs inside @spaces.GPU so CUDA is available throughout. model.generate()
    runs in a background thread (GPU is process-wide in the ZeroGPU subprocess).

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
        current_image_path: Unused — kept for API compatibility.

    Yields:
        Accumulated answer text after each token.
    """
    processor, model = load_model()
    if processor is None or model is None:
        yield f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return

    try:
        image = image.convert("RGB")
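        # The {"type": "image"} placeholder belongs only in the first user turn; on
        # follow-up turns it is already present in conversation_history, and exactly
        # one image is passed to the processor below.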
        if not conversation_history:
            new_user_turn = {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ]}
        else:
            new_user_turn = {"role": "user", "content": [
                {"type": "text", "text": question},
            ]}
        conversation = [*conversation_history, new_user_turn]
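        # apply_chat_template renders the turns into the model's chat prompt, and
        # add_generation_prompt=True appends the assistant prefix so generation
        # starts a fresh reply; the processor then tokenizes the prompt and
        # preprocesses the image.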
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
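        # TextIteratorStreamer yields decoded text chunks as generate() produces them;
        # skip_prompt drops the echoed prompt and skip_special_tokens strips EOS/BOS.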
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
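        # generate() blocks until completion, so run it in a plain background thread
        # and consume the streamer from this generator as chunks arrive; no extra
        # @spaces.GPU is needed because the ZeroGPU subprocess already has GPU access.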
        def _generate() -> None:
            with torch.inference_mode():
                model.generate(**inputs, max_new_tokens=4096, use_cache=True, streamer=streamer)

        thread = threading.Thread(target=_generate)
        thread.start()
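        # Yield the full accumulated text on every chunk so the caller can simply
        # replace the displayed message instead of stitching fragments together.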
        accumulated = ""
        for token in streamer:
            if token:
                accumulated += token
                yield accumulated
        thread.join()
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        yield f"Error: {e!s}"


@spaces.GPU(duration=120)
def answer_question(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> tuple[str, list[dict[str, Any]], str | None]:
    """Answer a question about an image using Granite Vision (non-streaming).

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
        current_image_path: Unused — kept for API compatibility.

    Returns:
        Tuple of (answer_text, updated_history, None).
    """
    processor, model = load_model()
    if processor is None or model is None:
        stub = f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return stub, conversation_history, None

    try:
        image = image.convert("RGB")
        if not conversation_history:
            new_user_turn = {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ]}
        else:
            new_user_turn = {"role": "user", "content": [
                {"type": "text", "text": question},
            ]}
        conversation = [*conversation_history, new_user_turn]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        max_new_tokens = 4096
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
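        # generate() returns prompt + completion token ids; slice off the prompt so
        # only the newly generated answer is decoded.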
        gen = outputs[0, inputs["input_ids"].shape[1]:]
        answer = processor.decode(gen, skip_special_tokens=True)
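        # Using the entire max_new_tokens budget usually means the answer was cut
        # off, so flag the likely truncation for the user.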
        if len(gen) >= max_new_tokens:
            answer += "\n\n[Max token limit reached — response may be truncated]"
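        # Append both the new user turn and the assistant reply so follow-up
        # questions carry the full conversation context.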
        updated_history = [
            *conversation_history,
            new_user_turn,
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return answer, updated_history, None
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        return f"Error: {e!s}", conversation_history, None