Spaces:
Running on Zero
Running on Zero
File size: 5,458 Bytes
"""Granite Vision image and chart Q&A.

Supports single-turn and multi-turn conversations using ibm-granite/granite-vision-4.1-4b.

ZeroGPU streaming pattern:
- answer_question_stream is decorated with @spaces.GPU, so the entire generator
  runs inside the ZeroGPU subprocess with GPU access.
- model.generate() is called in a plain background thread (no @spaces.GPU needed:
  the subprocess already has GPU access process-wide).
- The streamer is never passed across process boundaries, which avoids pickling errors.
- model_loader loads the model on CPU and then moves it to CUDA, so that
  caching_allocator_warmup does not trigger torch._C._cuda_init() before ZeroGPU
  can intercept it.
"""
import threading
from collections.abc import Generator
from typing import Any
import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from model_loader import load_model, load_processor
@spaces.GPU(duration=120)
def answer_question_stream(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> Generator[str, None, None]:
    """Stream an answer token-by-token about an image using Granite Vision.

    Runs inside @spaces.GPU, so CUDA is available throughout. model.generate()
    runs in a background thread (GPU access is process-wide in the ZeroGPU
    subprocess) while this generator drains a TextIteratorStreamer.

    Failure handling:
    - An exception raised by generate() in the worker thread is captured, the
      streamer is closed so the read loop terminates, and the exception is
      re-raised here. It is then reported as ``"Error: ..."`` instead of the
      read loop blocking indefinitely.
    - If the caller stops iterating early, a stop flag is set. A custom
      stopping criterion then ends generate() at its next step, and the worker
      thread is joined before the generator finishes.

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
            The image placeholder is added only when this list is empty. On
            later turns it presumably comes from the first user turn in the
            history (verify against the caller).
        current_image_path: Unused; kept for API compatibility.

    Yields:
        Accumulated answer text after each non-empty streamed chunk. Otherwise
        yields a single stub message (model not loaded) or ``"Error: ..."``.
    """
    processor, model = load_model()
    if processor is None or model is None:
        yield f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return
    try:
        from transformers import StoppingCriteria, StoppingCriteriaList

        image = image.convert("RGB")
        if not conversation_history:
            # First turn: the chat template needs an image slot before the text.
            new_user_turn = {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ]}
        else:
            new_user_turn = {"role": "user", "content": [
                {"type": "text", "text": question},
            ]}
        conversation = [*conversation_history, new_user_turn]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Set once the consumer is done (normally or early) so generate() can stop.
        stop_event = threading.Event()
        # Exception raised inside the worker thread, re-raised on this side.
        worker_errors: list[Exception] = []

        class _StopOnEvent(StoppingCriteria):
            """Report every sequence as finished once ``stop_event`` is set."""

            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> torch.BoolTensor:
                return torch.full(
                    (input_ids.shape[0],),
                    stop_event.is_set(),
                    dtype=torch.bool,
                    device=input_ids.device,
                )

        def _generate() -> None:
            try:
                with torch.inference_mode():
                    model.generate(
                        **inputs,
                        max_new_tokens=4096,
                        use_cache=True,
                        streamer=streamer,
                        stopping_criteria=StoppingCriteriaList([_StopOnEvent()]),
                    )
            except Exception as exc:  # noqa: BLE001
                worker_errors.append(exc)
                # generate() did not finish, so it never sent the streamer's
                # stop signal; send it here or the read loop below never ends.
                streamer.end()

        thread = threading.Thread(target=_generate)
        thread.start()
        accumulated = ""
        try:
            for token in streamer:
                if token:
                    accumulated += token
                    yield accumulated
        finally:
            # No-op after normal completion; on early close (GeneratorExit) it
            # makes generate() stop at its next step so the join is short.
            stop_event.set()
            thread.join()
        if worker_errors:
            raise worker_errors[0]
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        yield f"Error: {e!s}"
@spaces.GPU(duration=120)
def answer_question(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> tuple[str, list[dict[str, Any]], str | None]:
    """Answer a question about an image in a single, non-streaming call.

    The image placeholder is attached only when ``conversation_history`` is
    empty. On follow-up turns only the question text is added, but the image
    is still passed to the processor together with the full templated
    conversation.

    Args:
        image: PIL Image the question refers to.
        question: User question for this turn.
        conversation_history: Earlier turns as role/content dicts.
        current_image_path: Ignored; retained so existing callers keep working.

    Returns:
        ``(answer_text, updated_history, None)``. On success, the history is
        extended with this user turn and the assistant reply. When the model
        is not loaded, or any exception occurs, the incoming history is
        returned unchanged.
    """
    processor, model = load_model()
    if model is None or processor is None:
        return (
            f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded.",
            conversation_history,
            None,
        )

    token_budget = 4096
    try:
        rgb_image = image.convert("RGB")

        turn_parts: list[dict[str, Any]] = [{"type": "text", "text": question}]
        if not conversation_history:
            # Opening turn: the chat template expects the image slot first.
            turn_parts.insert(0, {"type": "image"})
        user_turn = {"role": "user", "content": turn_parts}

        prompt = processor.apply_chat_template(
            [*conversation_history, user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = processor(text=prompt, images=rgb_image, return_tensors="pt").to(model.device)

        with torch.inference_mode():
            generated = model.generate(**model_inputs, max_new_tokens=token_budget, use_cache=True)

        # Drop the echoed prompt tokens; keep only what the model produced.
        prompt_len = model_inputs["input_ids"].shape[1]
        new_tokens = generated[0, prompt_len:]
        reply = processor.decode(new_tokens, skip_special_tokens=True)
        if len(new_tokens) >= token_budget:
            reply += "\n\n[Max token limit reached — response may be truncated]"

        assistant_turn = {"role": "assistant", "content": [{"type": "text", "text": reply}]}
        return reply, [*conversation_history, user_turn, assistant_turn], None
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        return f"Error: {e!s}", conversation_history, None
|