"""Chart-to-CSV extraction using Granite Vision. Converts chart images to tabular CSV data using ibm-granite/granite-vision-4.1-4b. Same ZeroGPU streaming pattern as infer_vision_qa.py — see that module for details. """ import threading from collections.abc import Generator import spaces import torch from PIL import Image from transformers import TextIteratorStreamer from model_loader import load_model, load_processor @spaces.GPU(duration=120) def extract_csv_stream(image: Image.Image) -> Generator[str, None, None]: """Stream CSV extraction token-by-token from a chart image. Runs inside @spaces.GPU. model.generate() runs in a plain background thread. Args: image: PIL Image of a chart or table. Yields: Accumulated CSV text after each token. """ processor, model = load_model() if processor is None or model is None: yield "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6" return try: image = image.convert("RGB") conversation = [{"role": "user", "content": [ {"type": "image"}, {"type": "text", "text": ""}, ]}] text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) inputs = processor(text=text, images=image, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) def _generate() -> None: with torch.inference_mode(): model.generate(**inputs, max_new_tokens=4096, use_cache=True, streamer=streamer) thread = threading.Thread(target=_generate) thread.start() accumulated = "" for token in streamer: if token: accumulated += token yield accumulated thread.join() except Exception as e: # noqa: BLE001 import traceback traceback.print_exc() yield f"Error: {e!s}" @spaces.GPU(duration=120) def extract_csv(image: Image.Image) -> str: """Extract CSV data from a chart image using Granite Vision (non-streaming). Args: image: PIL Image of a chart or table. Returns: CSV-formatted text extracted from the chart. """ processor, model = load_model() if processor is None or model is None: return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6" try: image = image.convert("RGB") conversation = [{"role": "user", "content": [ {"type": "image"}, {"type": "text", "text": ""}, ]}] text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) inputs = processor(text=text, images=image, return_tensors="pt").to(model.device) max_new_tokens = 4096 with torch.inference_mode(): outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True) gen = outputs[0, inputs["input_ids"].shape[1]:] result = processor.decode(gen, skip_special_tokens=True) if len(gen) >= max_new_tokens: result += "\n\n[Max token limit reached — response may be truncated]" return result except Exception as e: # noqa: BLE001 import traceback traceback.print_exc() return f"Error: {e!s}"