| """Granite Vision image and chart Q&A. | |
| Supports single-turn and multi-turn conversations using ibm-granite/granite-vision-4.1-4b. | |
| ZeroGPU streaming pattern: | |
| - answer_question_stream is decorated with @spaces.GPU so the entire generator | |
| runs inside the ZeroGPU subprocess with GPU access. | |
| - model.generate() is called in a plain background thread (no @spaces.GPU needed — | |
| the subprocess already has GPU access process-wide). | |
| - The streamer is never passed across process boundaries, avoiding pickling errors. | |
| - model_loader loads the model on CPU then moves to CUDA to avoid caching_allocator_warmup | |
| triggering torch._C._cuda_init() before ZeroGPU can intercept it. | |
| """ | |
import threading
from collections.abc import Generator
from typing import Any

import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

from model_loader import load_model, load_processor
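# For reference, a minimal sketch of the CPU-then-CUDA load pattern the module
# docstring describes. The real loader lives in model_loader.py and may differ;
# the AutoModelForVision2Seq class, bfloat16 dtype, and the _load_model_sketch
# name below are assumptions, shown for illustration only and never called here.
def _load_model_sketch() -> tuple[Any, Any]:
    from transformers import AutoModelForVision2Seq, AutoProcessor

    model_id = "ibm-granite/granite-vision-4.1-4b"
    processor = AutoProcessor.from_pretrained(model_id)
    # Load weights on CPU first so nothing triggers torch._C._cuda_init()
    # before ZeroGPU has a chance to intercept it ...
    model = AutoModelForVision2Seq.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
    )
    # ... then move to CUDA once inside the ZeroGPU subprocess.
    return processor, model.to("cuda")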
@spaces.GPU
def answer_question_stream(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> Generator[str, None, None]:
    """Stream an answer token-by-token about an image using Granite Vision.

    Runs inside @spaces.GPU so CUDA is available throughout. model.generate()
    runs in a background thread (GPU is process-wide in the ZeroGPU subprocess).

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
        current_image_path: Unused — kept for API compatibility.

    Yields:
        Accumulated answer text after each token.
    """
    processor, model = load_model()
    if processor is None or model is None:
        yield f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return
    try:
        image = image.convert("RGB")

        # Only the first user turn carries the image placeholder; follow-up
        # turns are text-only so the template includes a single image token.
        if not conversation_history:
            new_user_turn = {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ]}
        else:
            new_user_turn = {"role": "user", "content": [
                {"type": "text", "text": question},
            ]}
        conversation = [*conversation_history, new_user_turn]

        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        # The streamer stays inside this subprocess; it is never pickled or
        # sent across process boundaries.
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
        # generate() runs in a plain thread; GPU access is process-wide in the
        # ZeroGPU subprocess, so no extra @spaces.GPU is needed on the worker.
        def _generate() -> None:
            with torch.inference_mode():
                model.generate(**inputs, max_new_tokens=4096, use_cache=True, streamer=streamer)

        thread = threading.Thread(target=_generate)
        thread.start()

        accumulated = ""
        for token in streamer:
            if token:
                accumulated += token
                yield accumulated
        thread.join()
    except Exception as e:  # noqa: BLE001
        import traceback

        traceback.print_exc()
        yield f"Error: {e!s}"
def answer_question(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> tuple[str, list[dict[str, Any]], str | None]:
    """Answer a question about an image using Granite Vision (non-streaming).

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
        current_image_path: Unused — kept for API compatibility.

    Returns:
        Tuple of (answer_text, updated_history, None).
    """
    processor, model = load_model()
    if processor is None or model is None:
        stub = f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return stub, conversation_history, None
    try:
        image = image.convert("RGB")

        # Only the first user turn carries the image placeholder; follow-up
        # turns are text-only.
        if not conversation_history:
            new_user_turn = {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ]}
        else:
            new_user_turn = {"role": "user", "content": [
                {"type": "text", "text": question},
            ]}
        conversation = [*conversation_history, new_user_turn]

        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        max_new_tokens = 4096
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)

        # Decode only the newly generated tokens, skipping the prompt.
        gen = outputs[0, inputs["input_ids"].shape[1]:]
        answer = processor.decode(gen, skip_special_tokens=True)
        if len(gen) >= max_new_tokens:
            answer += "\n\n[Max token limit reached — response may be truncated]"

        updated_history = [
            *conversation_history,
            new_user_turn,
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return answer, updated_history, None
    except Exception as e:  # noqa: BLE001
        import traceback

        traceback.print_exc()
        return f"Error: {e!s}", conversation_history, None