"""Granite Vision image and chart Q&A. Supports single-turn and multi-turn conversations using ibm-granite/granite-vision-4.1-4b. ZeroGPU streaming pattern: - answer_question_stream is decorated with @spaces.GPU so the entire generator runs inside the ZeroGPU subprocess with GPU access. - model.generate() is called in a plain background thread (no @spaces.GPU needed — the subprocess already has GPU access process-wide). - The streamer is never passed across process boundaries, avoiding pickling errors. - model_loader loads the model on CPU then moves to CUDA to avoid caching_allocator_warmup triggering torch._C._cuda_init() before ZeroGPU can intercept it. """ import threading from collections.abc import Generator from typing import Any import spaces import torch from PIL import Image from transformers import TextIteratorStreamer from model_loader import load_model, load_processor @spaces.GPU(duration=120) def answer_question_stream( image: Image.Image, question: str, conversation_history: list[dict[str, Any]], current_image_path: str | None, # noqa: ARG001 ) -> Generator[str, None, None]: """Stream an answer token-by-token about an image using Granite Vision. Runs inside @spaces.GPU so CUDA is available throughout. model.generate() runs in a background thread (GPU is process-wide in the ZeroGPU subprocess). Args: image: PIL Image to query. question: Question string to ask. conversation_history: Prior conversation turns (role/content dicts). current_image_path: Unused — kept for API compatibility. Yields: Accumulated answer text after each token. """ processor, model = load_model() if processor is None or model is None: yield f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded." return try: image = image.convert("RGB") if not conversation_history: new_user_turn = {"role": "user", "content": [ {"type": "image"}, {"type": "text", "text": question}, ]} else: new_user_turn = {"role": "user", "content": [ {"type": "text", "text": question}, ]} conversation = [*conversation_history, new_user_turn] text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) inputs = processor(text=text, images=image, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) def _generate() -> None: with torch.inference_mode(): model.generate(**inputs, max_new_tokens=4096, use_cache=True, streamer=streamer) thread = threading.Thread(target=_generate) thread.start() accumulated = "" for token in streamer: if token: accumulated += token yield accumulated thread.join() except Exception as e: # noqa: BLE001 import traceback traceback.print_exc() yield f"Error: {e!s}" @spaces.GPU(duration=120) def answer_question( image: Image.Image, question: str, conversation_history: list[dict[str, Any]], current_image_path: str | None, # noqa: ARG001 ) -> tuple[str, list[dict[str, Any]], str | None]: """Answer a question about an image using Granite Vision (non-streaming). Args: image: PIL Image to query. question: Question string to ask. conversation_history: Prior conversation turns (role/content dicts). current_image_path: Unused — kept for API compatibility. Returns: Tuple of (answer_text, updated_history, None). """ processor, model = load_model() if processor is None or model is None: stub = f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded." 
return stub, conversation_history, None try: image = image.convert("RGB") if not conversation_history: new_user_turn = {"role": "user", "content": [ {"type": "image"}, {"type": "text", "text": question}, ]} else: new_user_turn = {"role": "user", "content": [ {"type": "text", "text": question}, ]} conversation = [*conversation_history, new_user_turn] text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) inputs = processor(text=text, images=image, return_tensors="pt").to(model.device) max_new_tokens = 4096 with torch.inference_mode(): outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True) gen = outputs[0, inputs["input_ids"].shape[1]:] answer = processor.decode(gen, skip_special_tokens=True) if len(gen) >= max_new_tokens: answer += "\n\n[Max token limit reached — response may be truncated]" updated_history = [ *conversation_history, new_user_turn, {"role": "assistant", "content": [{"type": "text", "text": answer}]}, ] return answer, updated_history, None except Exception as e: # noqa: BLE001 import traceback traceback.print_exc() return f"Error: {e!s}", conversation_history, None
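

# --- Hypothetical local smoke test (illustrative sketch, not part of the Space UI) ---
# Shows how the two entry points above are expected to be called. The file name
# "sample_chart.png" and the questions are assumptions for illustration only.
# Outside a ZeroGPU Space the @spaces.GPU decorator typically passes the call
# through unchanged, so without a loaded model this exercises the stub path.
if __name__ == "__main__":
    demo_image = Image.open("sample_chart.png")  # assumed local sample image

    # Non-streaming call: returns (answer_text, updated_history, None).
    answer, history, _ = answer_question(
        demo_image, "What does this chart show?", [], None
    )
    print(answer)

    # Streaming call: yields the accumulated answer after each new token;
    # passing the updated history exercises the multi-turn branch.
    for partial in answer_question_stream(
        demo_image, "Summarize the key trend.", history, None
    ):
        print(partial, end="\r")
    print()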