Spaces:
Running on Zero
Running on Zero
| """Granite Vision Document Intelligence Demo. | |
| Upload a PDF or image to explore Granite-Vision-4.1-4B capabilities including | |
| Chart2CSV, Chart2Code, Chart2Summary, Table Extraction, and Image Q&A. | |
| """ | |
| # Monkey-patch gradio_client to handle bool JSON Schema values. | |
| # gradio 5.x emits additionalProperties: false/true (valid JSON Schema) | |
| # but gradio_client 1.5.x does not guard against bool in get_type(), | |
| # causing TypeError on every request to the /info endpoint. | |
| try: | |
| import gradio_client.utils as _gcu | |
| _orig_get_type = _gcu.get_type | |
| _orig_j2p = _gcu._json_schema_to_python_type | |
| def _patched_get_type(schema): # noqa: ANN001, ANN202 | |
| if not isinstance(schema, dict): | |
| return "unknown" | |
| return _orig_get_type(schema) | |
| def _patched_j2p(schema, defs=None): # noqa: ANN001, ANN202 | |
| if not isinstance(schema, dict): | |
| return "any" if schema else "unknown" | |
| return _orig_j2p(schema, defs) | |
| _gcu.get_type = _patched_get_type | |
| _gcu._json_schema_to_python_type = _patched_j2p | |
| except Exception: # noqa: BLE001 | |
| pass | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from PIL import Image | |
| from crops import extract_figures | |
| from document_parser import parse_document | |
| from infer_chart2csv import extract_csv_stream | |
| from infer_vision_qa import answer_question_stream | |
| from model_loader import load_processor | |
| from pdf_io import load_pdf_pages | |
| from themes.research_monochrome import theme | |
| from ui_state import create_initial_state, hash_bytes, page_cache, parse_cache | |
# Load the processor eagerly, when the module is imported. The original notes
# this is CPU-only work (no GPU needed); doing it here keeps that cost off the
# first user request.
load_processor()
| TITLE = "Granite Vision: Document Intelligence" | |
| DESCRIPTION = ( | |
| "Upload a PDF, Word, Excel, PowerPoint, or image to explore Granite-Vision-4.1-4B's document intelligence capabilities — " | |
| "including Chart2Summary, Chart2CSV, Chart2Code, Table Extraction, and Image Description — " | |
| "with automatic Docling-powered parsing for documents and direct inference on uploaded images." | |
| ) | |
| IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".jfif", ".png", ".bmp", ".dib", ".gif", ".tif", ".tiff", ".webp"} | |
| OFFICE_EXTENSIONS = {".docx", ".xlsx", ".pptx"} | |
| css_file_path = Path(Path(__file__).parent / "app.css") | |
| head_file_path = Path(Path(__file__).parent / "app_head.html") | |
| def _is_image_file(file_path: str) -> bool: | |
| """Check whether a file path points to a supported image format.""" | |
| ext = os.path.splitext(file_path)[1].lower() | |
| return ext in IMAGE_EXTENSIONS | |
| def _is_office_file(file_path: str) -> bool: | |
| """Check whether a file path points to a supported Office format (DOCX/XLSX/PPTX).""" | |
| ext = os.path.splitext(file_path)[1].lower() | |
| return ext in OFFICE_EXTENSIONS | |
def _reset_document_state(session_state: dict[str, Any]) -> None:
    """Clear every per-document field of *session_state* in place.

    Runs before each upload attempt and again when an upload fails. Without
    this, the figure tabs keep showing figures from the previously loaded
    document.
    """
    session_state["current_figure_index"] = 0
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    session_state["page_images"] = []
    session_state["parsed_result"] = {}
    session_state["figures_info"] = []
    session_state["selected_figure"] = None


def _load_image_upload(file_path: str, session_state: dict[str, Any]) -> tuple:
    """Register a directly uploaded image as the session's only figure.

    Returns the same 6-tuple as ``process_upload``.
    """
    # The context manager closes the file handle. convert() returns a new
    # image object, so the result stays usable after the file is closed.
    with Image.open(file_path) as img:
        image = img.convert("RGB")
    figures_info = [{"image": image, "page": 0, "bbox": None, "caption": ""}]
    session_state["page_images"] = [image]
    session_state["parsed_result"] = {}
    session_state["figures_info"] = figures_info
    session_state["selected_figure"] = figures_info[0]
    return (
        "Image loaded successfully.\nNumber of figures: 1.",
        "Image uploaded directly (no document parsing needed)",
        "Figure 1 of 1 (Page 1)",
        "",
        image,
        session_state,
    )


def _load_document_upload(
    file_path: str,
    file_bytes: bytes,
    file_hash: Any,
    session_state: dict[str, Any],
    max_pages: int,
) -> tuple:
    """Render (PDF only), parse and extract figures from a document upload.

    Page renders and parse results are memoized in ``page_cache`` /
    ``parse_cache``, keyed by the content hash. Returns the same 6-tuple as
    ``process_upload``.
    """
    file_ext = os.path.splitext(file_path)[1].lower()
    fmt_label = file_ext.lstrip(".").upper()
    status_lines = [f"{fmt_label} loaded successfully."]

    if _is_office_file(file_path):
        # Office formats are not rasterized here, so no page images exist.
        page_images = []
        session_state["page_images"] = []
    else:
        cache_key = f"{file_hash}_{max_pages}"
        if cache_key in page_cache:
            page_images = page_cache[cache_key]
        else:
            page_images = load_pdf_pages(file_bytes, max_pages=max_pages)
            page_cache[cache_key] = page_images
        session_state["page_images"] = page_images
        status_lines.append(f"Number of pages rendered: {len(page_images)} (max {max_pages}).")

    if file_hash in parse_cache:
        parse_result = parse_cache[file_hash]
    else:
        parse_result = parse_document(file_bytes, file_ext=file_ext)
        parse_cache[file_hash] = parse_result
    session_state["parsed_result"] = parse_result
    status_lines.append("Document parsing done using Docling.")

    figures_info = extract_figures(page_images, parse_result.get("figures", []))
    session_state["figures_info"] = figures_info
    status_lines.append(f"Number of figures extracted: {len(figures_info)}.")

    if figures_info:
        first = figures_info[0]
        session_state["selected_figure"] = first
        fig_status = f"Figure 1 of {len(figures_info)} (Page {first['page'] + 1})"
        fig_caption = first.get("caption", "No caption")
        fig_image = first["image"]
    else:
        session_state["selected_figure"] = None
        fig_status = "No figures found"
        fig_caption = ""
        fig_image = None

    html_content = parse_result.get("html", "No content available")
    return "\n".join(status_lines), html_content, fig_status, fig_caption, fig_image, session_state


def process_upload(file_path: str, session_state: dict[str, Any]) -> tuple:
    """Parse an uploaded PDF/Office document, or load an image, and extract figures.

    All per-document session fields are reset first. They are reset again if
    processing fails, so a failed or empty upload never leaves the previous
    document's figures selectable.

    Args:
        file_path: Path to the uploaded file (falsy when nothing was uploaded).
        session_state: Current Gradio session state dictionary (mutated in place).

    Returns:
        Tuple of (status, html_content, fig_status, fig_caption, fig_image, session_state).
    """
    max_pages = 20
    _reset_document_state(session_state)
    if not file_path:
        return "Please upload a PDF, Office document, or image.", "No document loaded", "No figures", "", None, session_state
    try:
        with open(file_path, "rb") as f:
            file_bytes = f.read()
        file_hash = hash_bytes(file_bytes)
        session_state["uploaded_file_hash"] = file_hash
        session_state["uploaded_file_bytes"] = file_bytes
        if _is_image_file(file_path):
            return _load_image_upload(file_path, session_state)
        return _load_document_upload(file_path, file_bytes, file_hash, session_state, max_pages)
    except Exception as e:  # noqa: BLE001
        import traceback

        print(f"Error: {e}")
        traceback.print_exc()
        # Discard any half-populated document state from the failed attempt.
        _reset_document_state(session_state)
        return f"Error: {e!s}", f"Error loading document: {e!s}", "Error", "", None, session_state
| def _get_figure_display(session_state: dict[str, Any]) -> tuple[str, str, Image.Image | None]: | |
| """Return the current figure's display info, caption, and image. | |
| Args: | |
| session_state: Current session state dictionary. | |
| Returns: | |
| Tuple of (fig_status, fig_caption, fig_image). | |
| """ | |
| figures_info = session_state.get("figures_info", []) | |
| idx = session_state.get("current_figure_index", 0) | |
| if not figures_info: | |
| return "No figures found", "", None | |
| fig = figures_info[idx] | |
| fig_status = f"Figure {idx + 1} of {len(figures_info)} (Page {fig['page'] + 1})" | |
| fig_caption = fig.get("caption", "No caption") | |
| return fig_status, fig_caption, fig["image"] | |
def _step_figure(session_state: dict[str, Any], step: int) -> tuple:
    """Move the figure cursor by *step*, wrapping around, and reset chat state.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    figures = session_state.get("figures_info", [])
    if not figures:
        return "No figures found", "", None, session_state

    position = (session_state.get("current_figure_index", 0) + step) % len(figures)
    session_state["current_figure_index"] = position
    session_state["selected_figure"] = figures[position]
    # A different figure means any prior conversation no longer applies.
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None

    status_text, caption_text, image = _get_figure_display(session_state)
    return status_text, caption_text, image, session_state


def next_figure(session_state: dict[str, Any]) -> tuple:
    """Advance to the next figure, wrapping from the last back to the first.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    return _step_figure(session_state, 1)


def prev_figure(session_state: dict[str, Any]) -> tuple:
    """Go back to the previous figure, wrapping from the first to the last.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    return _step_figure(session_state, -1)
def describe_image_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream a detailed description of the selected figure.

    Yields:
        ``(text, session_state)`` pairs. The text is each successive output of
        ``answer_question_stream``, ``"No figure selected"`` when nothing is
        selected, or ``"Error: ..."`` if inference raises.
    """
    figure = session_state.get("selected_figure")
    if figure is not None:
        try:
            for chunk in answer_question_stream(figure["image"], "Describe this image in detail", [], None):
                yield chunk, session_state
        except Exception as exc:  # noqa: BLE001
            yield f"Error: {exc!s}", session_state
    else:
        yield "No figure selected", session_state
def load_current_figure(session_state: dict[str, Any]) -> tuple:
    """Refresh a tab's figure panel from the session; bound to each tab's select event.

    The conversation history and current image path are cleared first, so
    every tab starts with a clean slate.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    session_state.update(conversation_history=[], current_image_path=None)
    status_text, caption_text, image = _get_figure_display(session_state)
    return status_text, caption_text, image, session_state
# Prompts passed to answer_question_stream by the Chart2Code, Chart2Summary and
# Table Extraction helpers, respectively.
PROMPT_TEXT_CODE = (
    "Please take a look at this chart image and generate Python code "
    "that perfectly reconstructs this chart image."
)
PROMPT_TEXT_SUMMARY = "<chart2summary>"
PROMPT_TEXT_TABLE = "<tables_html>"
def extract_code_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream Python code that reconstructs the chart in the selected figure.

    Yields:
        ``(text, session_state)`` pairs: successive outputs of
        ``answer_question_stream`` for ``PROMPT_TEXT_CODE``,
        ``"No figure selected"``, or ``"Error: ..."`` on failure.
    """
    chart = session_state.get("selected_figure")
    if chart is None:
        yield "No figure selected", session_state
        return
    try:
        stream = answer_question_stream(chart["image"], PROMPT_TEXT_CODE, [], None)
        for code_text in stream:
            yield code_text, session_state
    except Exception as err:  # noqa: BLE001
        yield f"Error: {err!s}", session_state
def extract_summary_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream a text summary of the chart in the selected figure.

    Yields:
        ``(text, session_state)`` pairs: successive outputs of
        ``answer_question_stream`` for ``PROMPT_TEXT_SUMMARY``,
        ``"No figure selected"``, or ``"Error: ..."`` on failure.
    """
    current = session_state.get("selected_figure")
    if current is None:
        yield "No figure selected", session_state
        return
    try:
        chart_image = current["image"]
        for summary_text in answer_question_stream(chart_image, PROMPT_TEXT_SUMMARY, [], None):
            yield summary_text, session_state
    except Exception as failure:  # noqa: BLE001
        yield f"Error: {failure!s}", session_state
| def extract_table_helper(session_state: dict[str, Any]): # noqa: ANN201 | |
| """Extract tables as HTML from the selected figure (streaming).""" | |
| import re | |
| selected_fig = session_state.get("selected_figure") | |
| if selected_fig is None: | |
| yield "No figure selected", session_state | |
| return | |
| try: | |
| image = selected_fig["image"] | |
| accumulated = "" | |
| for partial in answer_question_stream(image, PROMPT_TEXT_TABLE, [], None): | |
| accumulated = partial | |
| yield accumulated, session_state | |
| # Final pass: strip markdown fences / brackets the model may wrap around HTML | |
| cleaned = re.sub(r"^```(?:html)?\s*", "", accumulated.strip()) | |
| cleaned = re.sub(r"\s*```$", "", cleaned.strip()) | |
| cleaned = re.sub(r"^\[\s*", "", cleaned.strip()) | |
| cleaned = re.sub(r"\s*\]$", "", cleaned.strip()) | |
| if cleaned != accumulated: | |
| yield cleaned, session_state | |
| except Exception as e: # noqa: BLE001 | |
| yield f"Error: {e!s}", session_state | |
def extract_csv_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream CSV data extracted from the chart in the selected figure.

    Yields:
        ``(text, session_state)`` pairs: successive outputs of
        ``extract_csv_stream``, ``"No figure selected"``, or ``"Error: ..."``
        on failure.
    """
    chart = session_state.get("selected_figure")
    if chart is None:
        yield "No figure selected", session_state
        return
    try:
        csv_stream = extract_csv_stream(chart["image"])
        for csv_text in csv_stream:
            yield csv_text, session_state
    except Exception as err:  # noqa: BLE001
        yield f"Error: {err!s}", session_state
| def _make_nav(nav_fn: Any) -> Any: | |
| """Wrap a nav function to also clear the result panel when navigating figures.""" | |
| def _wrapper(session_state: dict[str, Any]) -> tuple: | |
| fig_status, fig_caption, fig_image, state = nav_fn(session_state) | |
| return fig_status, fig_caption, fig_image, "", state | |
| return _wrapper | |
# Every suffix the upload widget accepts: PDF plus the Office and image formats
# that process_upload routes. Derived from the constants so they cannot drift.
_UPLOAD_FILE_TYPES = [".pdf", *sorted(OFFICE_EXTENSIONS), *sorted(IMAGE_EXTENSIONS)]


def _build_figure_tab(
    label: str,
    intro: str,
    result_heading: str,
    action_label: str,
    make_output: Any,
    action_fn: Any,
    state: gr.State,
) -> None:
    """Create one figure-centric tab and wire its events.

    Must be called inside an active ``gr.Blocks`` / ``gr.Tabs`` context. The left
    column shows the current figure with Previous/Next buttons; the right column
    holds the action button and its result component.

    Args:
        label: Tab title.
        intro: Markdown line shown at the top of the tab.
        result_heading: Heading text above the result column.
        action_label: Text of the primary action button.
        make_output: Zero-argument factory that creates the result component.
        action_fn: Handler taking the session state and yielding
            ``(result, session_state)`` pairs.
        state: The shared session ``gr.State``.
    """
    with gr.Tab(label) as tab:
        gr.Markdown(intro)
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Figure")
                tab_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                tab_fig_caption = gr.Textbox(label="Caption", interactive=False)
                tab_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                with gr.Row():
                    tab_prev_btn = gr.Button("Previous", scale=1)
                    tab_next_btn = gr.Button("Next", scale=1)
            with gr.Column(scale=1):
                gr.Markdown(f"### {result_heading}")
                action_btn = gr.Button(action_label, variant="primary")
                result = make_output()

    figure_outputs = [tab_fig_info, tab_fig_caption, tab_fig_image]
    # Navigating to another figure also blanks the result, so the panel never
    # shows output for a figure other than the one displayed.
    tab_prev_btn.click(_make_nav(prev_figure), inputs=[state], outputs=[*figure_outputs, result, state])
    tab_next_btn.click(_make_nav(next_figure), inputs=[state], outputs=[*figure_outputs, result, state])
    action_btn.click(action_fn, inputs=[state], outputs=[result, state])
    tab.select(load_current_figure, inputs=[state], outputs=[*figure_outputs, state])


with gr.Blocks(
    title=TITLE,
    theme=theme,
    css_paths=css_file_path,
    head_paths=head_file_path,
    fill_height=True,
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    session_state = gr.State(create_initial_state())

    with gr.Tabs():
        # TAB 1: UPLOAD & PARSE
        with gr.Tab("Parse & Extract"):
            file_path = gr.File(
                label="Upload PDF, Office Document, or Image",
                file_types=_UPLOAD_FILE_TYPES,
            )
            status = gr.Textbox(label="Status", interactive=False, lines=2)
            with gr.Row():
                with gr.Column(scale=1):
                    html_view = gr.Textbox(
                        label="Parsed Document (Docling)",
                        value="Upload a document to see parsed content",
                        lines=35,
                        interactive=False,
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Extracted Figures")
                    fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    fig_caption = gr.Textbox(label="Caption", interactive=False)
                    fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        prev_btn = gr.Button("Previous", scale=1)
                        next_btn = gr.Button("Next", scale=1)
            file_path.upload(
                process_upload,
                inputs=[file_path, session_state],
                outputs=[status, html_view, fig_info, fig_caption, fig_image, session_state],
            )
            next_btn.click(
                next_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state],
            )
            prev_btn.click(
                prev_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state],
            )

        # TAB 2: CHART2SUMMARY
        _build_figure_tab(
            "Chart2Summary",
            "Generate a text summary of the selected chart",
            "Summary",
            "Generate Summary",
            lambda: gr.Textbox(label="Chart Summary", lines=20, interactive=False),
            extract_summary_helper,
            session_state,
        )
        # TAB 3: CHART2CSV
        _build_figure_tab(
            "Chart2CSV",
            "Extract CSV data from the selected chart",
            "CSV Extraction",
            "Extract CSV",
            lambda: gr.Textbox(label="CSV", lines=20, interactive=False),
            extract_csv_helper,
            session_state,
        )
        # TAB 4: CHART2CODE
        _build_figure_tab(
            "Chart2Code",
            "Generate Python code to reconstruct the selected chart",
            "Generated Code",
            "Generate Code",
            lambda: gr.Textbox(label="Python Code", lines=20, interactive=False),
            extract_code_helper,
            session_state,
        )
        # TAB 5: TABLE EXTRACTION
        _build_figure_tab(
            "Table Extraction",
            "Extract table data as HTML from the selected figure",
            "Table Extraction",
            "Extract Table",
            lambda: gr.HTML(value="<p>Upload a document and click Extract Table to see results here</p>"),
            extract_table_helper,
            session_state,
        )
        # TAB 6: IMAGE DESCRIPTION
        _build_figure_tab(
            "Image Description",
            "Get a detailed description of the selected figure",
            "Description",
            "Describe Image",
            lambda: gr.Textbox(label="Description", lines=20, interactive=False),
            describe_image_helper,
            session_state,
        )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)