File size: 5,458 Bytes
49574d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Granite Vision image and chart Q&A.

Supports single-turn and multi-turn conversations using ibm-granite/granite-vision-4.1-4b.

ZeroGPU streaming pattern:
- answer_question_stream is decorated with @spaces.GPU so the entire generator
  runs inside the ZeroGPU subprocess with GPU access.
- model.generate() is called in a plain background thread (no @spaces.GPU needed —
  the subprocess already has GPU access process-wide).
- The streamer is never passed across process boundaries, avoiding pickling errors.
- model_loader loads the model on CPU then moves to CUDA to avoid caching_allocator_warmup
  triggering torch._C._cuda_init() before ZeroGPU can intercept it.
"""

import threading
from collections.abc import Generator
from typing import Any

import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

from model_loader import load_model, load_processor


@spaces.GPU(duration=120)
def answer_question_stream(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> Generator[str, None, None]:
    """Stream an answer token-by-token about an image using Granite Vision.

    Runs inside @spaces.GPU so CUDA is available throughout. model.generate()
    runs in a background thread (GPU is process-wide in the ZeroGPU subprocess)
    and feeds a TextIteratorStreamer that this generator drains.

    If generation fails inside the worker thread, the stream is terminated
    explicitly and the failure is reported as an ``Error: ...`` string instead
    of leaving the consumer blocked on the streamer forever.

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
            When empty, the new user turn carries the image placeholder.
        current_image_path: Unused — kept for API compatibility.

    Yields:
        Accumulated answer text after each streamed chunk. If the token budget
        was exhausted, one final value with a truncation notice appended
        (matching answer_question). On failure, a single ``Error: ...`` string.
        When the model is not loaded, a single stub response.
    """
    processor, model = load_model()

    if processor is None or model is None:
        yield f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return

    try:
        image = image.convert("RGB")

        # Only the first turn carries the image placeholder; follow-up turns are text-only.
        if not conversation_history:
            new_user_turn = {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": question},
            ]}
        else:
            new_user_turn = {"role": "user", "content": [
                {"type": "text", "text": question},
            ]}

        conversation = [*conversation_history, new_user_turn]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        max_new_tokens = 4096
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Written by the worker thread, read here only after thread.join().
        result: dict[str, Any] = {}

        def _generate() -> None:
            try:
                with torch.inference_mode():
                    result["outputs"] = model.generate(
                        **inputs, max_new_tokens=max_new_tokens, use_cache=True, streamer=streamer,
                    )
            except Exception as exc:  # noqa: BLE001
                result["error"] = exc
                # generate() only signals end-of-stream on success; without this the
                # consumer loop below would block forever waiting on the streamer.
                streamer.end()

        thread = threading.Thread(target=_generate)
        thread.start()

        accumulated = ""
        for token in streamer:
            if token:
                accumulated += token
                yield accumulated

        thread.join()

        if "error" in result:
            # Surface the worker's failure through the shared error path below.
            raise result["error"]

        # Mirror answer_question: flag responses that hit the token budget.
        outputs = result.get("outputs")
        if isinstance(outputs, torch.Tensor):
            generated_len = outputs.shape[1] - inputs["input_ids"].shape[1]
            if generated_len >= max_new_tokens:
                accumulated += "\n\n[Max token limit reached — response may be truncated]"
                yield accumulated

    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        yield f"Error: {e!s}"


@spaces.GPU(duration=120)
def answer_question(
    image: Image.Image,
    question: str,
    conversation_history: list[dict[str, Any]],
    current_image_path: str | None,  # noqa: ARG001
) -> tuple[str, list[dict[str, Any]], str | None]:
    """Answer a question about an image in one blocking call (no streaming).

    The opening turn of a conversation embeds the image placeholder alongside
    the question; later turns send only the question text. Generation is capped
    at 4096 new tokens, and a notice is appended when that cap is reached.

    Args:
        image: PIL Image to query.
        question: Question string to ask.
        conversation_history: Prior conversation turns (role/content dicts).
        current_image_path: Unused — kept for API compatibility.

    Returns:
        A 3-tuple ``(answer_text, history, None)``. On success ``history`` is a
        new list extended with this user turn and the assistant reply. When the
        model is not loaded, or an exception occurs, the original
        ``conversation_history`` is returned unchanged alongside a stub/error text.
    """
    processor, model = load_model()

    if model is None or processor is None:
        placeholder = f"[STUB] Question: {question}\n\nThis is a placeholder response. Model not loaded."
        return placeholder, conversation_history, None

    try:
        rgb_image = image.convert("RGB")

        # The image placeholder only belongs in the very first user turn.
        parts: list[dict[str, str]] = [] if conversation_history else [{"type": "image"}]
        parts.append({"type": "text", "text": question})
        user_turn = {"role": "user", "content": parts}

        prompt = processor.apply_chat_template(
            [*conversation_history, user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = processor(text=prompt, images=rgb_image, return_tensors="pt").to(model.device)

        token_budget = 4096
        with torch.inference_mode():
            sequences = model.generate(**model_inputs, max_new_tokens=token_budget, use_cache=True)

        # generate() echoes the prompt; keep only the freshly produced tokens.
        prompt_length = model_inputs["input_ids"].shape[1]
        new_tokens = sequences[0, prompt_length:]
        answer = processor.decode(new_tokens, skip_special_tokens=True)
        if len(new_tokens) >= token_budget:
            answer = answer + "\n\n[Max token limit reached — response may be truncated]"

        assistant_turn = {"role": "assistant", "content": [{"type": "text", "text": answer}]}
        return answer, [*conversation_history, user_turn, assistant_turn], None

    except Exception as err:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        return f"Error: {err!s}", conversation_history, None