File size: 3,408 Bytes
49574d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Chart-to-CSV extraction using Granite Vision.

Converts chart images to tabular CSV data using ibm-granite/granite-vision-4.1-4b.

Same ZeroGPU streaming pattern as infer_vision_qa.py — see that module for details.
"""

import threading
from collections.abc import Generator

import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer

from model_loader import load_model, load_processor


@spaces.GPU(duration=120)
def extract_csv_stream(image: Image.Image) -> Generator[str, None, None]:
    """Stream CSV extraction from a chart image as generation progresses.

    Runs inside @spaces.GPU. model.generate() runs in a plain background thread
    that feeds a TextIteratorStreamer; this generator drains the streamer on the
    calling thread and yields the text accumulated so far.

    Args:
        image: PIL Image of a chart or table.

    Yields:
        Accumulated CSV text after each decoded chunk. If generation stopped
        because it reached the token limit, one final update appends a
        truncation notice (same wording as ``extract_csv``). If no model is
        loaded, a fixed placeholder table is yielded. On failure, including a
        failure inside the generation thread, ``"Error: <message>"`` is yielded.
    """
    processor, model = load_model()

    if processor is None or model is None:
        # Placeholder output used when load_model() returned no model.
        yield "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
        return

    try:
        image = image.convert("RGB")
        conversation = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "<chart2csv>"},
        ]}]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        max_new_tokens = 4096
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Hand-off slots filled by the worker thread: generate()'s return value
        # on success, or the exception it raised on failure.
        results: list = []
        errors: list[Exception] = []

        def _generate() -> None:
            try:
                with torch.inference_mode():
                    results.append(
                        model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True, streamer=streamer)
                    )
            except Exception as exc:  # noqa: BLE001 — re-raised on the consumer thread below
                errors.append(exc)
                # If generate() aborts, the streamer may never receive its end-of-stream
                # signal, and the consumer loop below would block on the queue forever.
                streamer.end()

        thread = threading.Thread(target=_generate)
        thread.start()

        accumulated = ""
        for token in streamer:
            if token:
                accumulated += token
                yield accumulated

        thread.join()

        if errors:
            # Route worker-thread failures through the same error path as
            # failures on this thread.
            raise errors[0]

        if results:
            # generate() returns prompt + new tokens; compare only the new part
            # against the budget, matching extract_csv's truncation check.
            generated_len = results[0].shape[1] - inputs["input_ids"].shape[1]
            if generated_len >= max_new_tokens:
                yield accumulated + "\n\n[Max token limit reached — response may be truncated]"

    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        yield f"Error: {e!s}"


@spaces.GPU(duration=120)
def extract_csv(image: Image.Image) -> str:
    """Run chart-to-CSV extraction and return the whole result in one piece.

    Non-streaming variant: model.generate() runs directly on the calling thread,
    and only the tokens produced after the prompt are decoded.

    Args:
        image: PIL Image of a chart or table.

    Returns:
        The decoded CSV text. A truncation notice is appended when the number
        of generated tokens reaches the 4096-token budget. If no model is
        loaded, a fixed placeholder table is returned. On any failure, the
        string ``"Error: <message>"`` is returned after the traceback is
        printed.
    """
    processor, model = load_model()

    if model is None or processor is None:
        # Placeholder output used when load_model() returned no model.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"

    try:
        rgb_image = image.convert("RGB")
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "<chart2csv>"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        model_inputs = processor(text=prompt, images=rgb_image, return_tensors="pt").to(model.device)

        token_budget = 4096
        with torch.inference_mode():
            sequences = model.generate(**model_inputs, max_new_tokens=token_budget, use_cache=True)

        # generate() echoes the prompt; slice it off so only new tokens are decoded.
        new_tokens = sequences[0, model_inputs["input_ids"].shape[1]:]
        decoded = processor.decode(new_tokens, skip_special_tokens=True)

        hit_limit = len(new_tokens) >= token_budget
        if not hit_limit:
            return decoded
        return decoded + "\n\n[Max token limit reached — response may be truncated]"

    except Exception as exc:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        return f"Error: {exc!s}"