# Deploybot: deployed from stable branch (commit 49574d5)
"""Chart-to-CSV extraction using Granite Vision.
Converts chart images to tabular CSV data using ibm-granite/granite-vision-4.1-4b.
Same ZeroGPU streaming pattern as infer_vision_qa.py — see that module for details.
"""
import threading
from collections.abc import Generator
import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from model_loader import load_model, load_processor
@spaces.GPU(duration=120)
def extract_csv_stream(image: Image.Image) -> Generator[str, None, None]:
    """Stream CSV extraction token-by-token from a chart image.

    Runs inside @spaces.GPU. model.generate() runs in a plain background
    thread; this generator consumes tokens via TextIteratorStreamer.

    Args:
        image: PIL Image of a chart or table.

    Yields:
        Accumulated CSV text after each token.
    """
    processor, model = load_model()
    if processor is None or model is None:
        # Model unavailable (e.g. local dev without weights): yield a stub table.
        yield "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
        return
    try:
        image = image.convert("RGB")
        conversation = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "<chart2csv>"},
        ]}]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # If generate() raises in the worker thread, TextIteratorStreamer never
        # receives its end signal and iterating it would block this generator
        # forever. Capture the exception, unblock the stream, and re-raise it
        # here so the except-block below reports it to the caller.
        gen_error: list[BaseException] = []

        def _generate() -> None:
            try:
                with torch.inference_mode():
                    model.generate(**inputs, max_new_tokens=4096, use_cache=True, streamer=streamer)
            except BaseException as err:  # noqa: BLE001 — re-raised on the main thread below
                gen_error.append(err)
                streamer.end()  # unblock the consumer loop

        thread = threading.Thread(target=_generate)
        thread.start()
        accumulated = ""
        for token in streamer:
            if token:
                accumulated += token
                yield accumulated
        thread.join()
        if gen_error:
            raise gen_error[0]
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        yield f"Error: {e!s}"
@spaces.GPU(duration=120)
def extract_csv(image: Image.Image) -> str:
    """Extract CSV data from a chart image using Granite Vision (non-streaming).

    Args:
        image: PIL Image of a chart or table.

    Returns:
        CSV-formatted text extracted from the chart.
    """
    processor, model = load_model()
    if processor is None or model is None:
        # Model unavailable: fall back to a placeholder table.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
    try:
        rgb = image.convert("RGB")
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "<chart2csv>"},
        ]}]
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        batch = processor(text=prompt, images=rgb, return_tensors="pt").to(model.device)
        token_budget = 4096
        with torch.inference_mode():
            generated = model.generate(**batch, max_new_tokens=token_budget, use_cache=True)
        # Slice off the prompt so only newly generated tokens are decoded.
        prompt_len = batch["input_ids"].shape[1]
        new_tokens = generated[0, prompt_len:]
        csv_text = processor.decode(new_tokens, skip_special_tokens=True)
        # Warn the caller when generation hit the budget — output may be cut off.
        if len(new_tokens) >= token_budget:
            csv_text += "\n\n[Max token limit reached — response may be truncated]"
        return csv_text
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        return f"Error: {e!s}"