# Hugging Face Space — runs on ZeroGPU hardware.
"""Chart-to-CSV extraction using Granite Vision.

Converts chart images to tabular CSV data using ibm-granite/granite-vision-4.1-4b.
Same ZeroGPU streaming pattern as infer_vision_qa.py — see that module for details.
"""
| import threading | |
| from collections.abc import Generator | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from transformers import TextIteratorStreamer | |
| from model_loader import load_model, load_processor | |
def extract_csv_stream(image: Image.Image) -> Generator[str, None, None]:
    """Stream CSV extraction token-by-token from a chart image.

    Runs inside @spaces.GPU. model.generate() runs in a plain background
    thread; this generator consumes the streamer on the calling thread.

    Args:
        image: PIL Image of a chart or table.

    Yields:
        Accumulated CSV text after each token, or a single "Error: ..."
        string if generation fails.
    """
    processor, model = load_model()
    if processor is None or model is None:
        # Model unavailable (e.g. local dev without weights): placeholder CSV.
        yield "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
        return
    try:
        image = image.convert("RGB")
        conversation = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "<chart2csv>"},
        ]}]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Exceptions raised in the worker thread would otherwise be lost AND
        # leave the streamer's queue open, hanging the consumer loop forever.
        # Capture them here and re-raise on the consumer side after join().
        worker_errors: list[BaseException] = []

        def _generate() -> None:
            try:
                with torch.inference_mode():
                    model.generate(**inputs, max_new_tokens=4096, use_cache=True, streamer=streamer)
            except BaseException as exc:  # noqa: BLE001 — forwarded to consumer below
                worker_errors.append(exc)
                # Signal end-of-stream so the `for token in streamer` loop
                # terminates instead of blocking indefinitely.
                streamer.end()

        # Daemon thread: a wedged generate() must not block interpreter exit.
        thread = threading.Thread(target=_generate, daemon=True)
        thread.start()
        accumulated = ""
        for token in streamer:
            if token:
                accumulated += token
                yield accumulated
        thread.join()
        if worker_errors:
            # Surface the generation failure through the normal error path.
            raise worker_errors[0]
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        yield f"Error: {e!s}"
def extract_csv(image: Image.Image) -> str:
    """Extract CSV data from a chart image using Granite Vision (non-streaming).

    Args:
        image: PIL Image of a chart or table.

    Returns:
        CSV-formatted text extracted from the chart, or an "Error: ..."
        string if generation fails.
    """
    processor, model = load_model()
    if processor is None or model is None:
        # Model unavailable (e.g. local dev without weights): placeholder CSV.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
    try:
        rgb = image.convert("RGB")
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "<chart2csv>"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        model_inputs = processor(text=prompt, images=rgb, return_tensors="pt").to(model.device)
        token_budget = 4096
        with torch.inference_mode():
            generated = model.generate(**model_inputs, max_new_tokens=token_budget, use_cache=True)
        # Strip the prompt tokens; decode only the newly generated suffix.
        prompt_len = model_inputs["input_ids"].shape[1]
        new_tokens = generated[0, prompt_len:]
        decoded = processor.decode(new_tokens, skip_special_tokens=True)
        # Warn when the budget was exhausted — output is likely cut off.
        if len(new_tokens) >= token_budget:
            decoded += "\n\n[Max token limit reached — response may be truncated]"
        return decoded
    except Exception as e:  # noqa: BLE001
        import traceback
        traceback.print_exc()
        return f"Error: {e!s}"