"""Granite Vision Document Intelligence Demo.

Upload a PDF, Office document (DOCX/XLSX/PPTX), or image to explore
Granite-Vision-4.1-4B capabilities including Chart2CSV, Chart2Code,
Chart2Summary, Table Extraction, and Image Description.
"""

# Monkey-patch gradio_client to handle bool JSON Schema values.
# gradio 5.x emits additionalProperties: false/true (valid JSON Schema)
# but gradio_client 1.5.x does not guard against bool in get_type(),
# causing TypeError on every request to the /info endpoint.
try:
    import gradio_client.utils as _gcu

    _orig_get_type = _gcu.get_type
    _orig_j2p = _gcu._json_schema_to_python_type

    def _patched_get_type(schema):  # noqa: ANN001, ANN202
        """Return "unknown" for non-dict schemas, else defer to the original."""
        if not isinstance(schema, dict):
            return "unknown"
        return _orig_get_type(schema)

    def _patched_j2p(schema, defs=None):  # noqa: ANN001, ANN202
        """Map a bare ``True`` schema to "any" and other non-dicts to "unknown"."""
        if not isinstance(schema, dict):
            return "any" if schema else "unknown"
        return _orig_j2p(schema, defs)

    _gcu.get_type = _patched_get_type
    _gcu._json_schema_to_python_type = _patched_j2p
except Exception:  # noqa: BLE001
    # Best-effort patch: if gradio_client is missing or its private API has
    # changed, run unpatched rather than failing at import time.
    pass

import os
import re
import traceback
from pathlib import Path
from typing import Any

import gradio as gr
from PIL import Image

from crops import extract_figures
from document_parser import parse_document
from infer_chart2csv import extract_csv_stream
from infer_vision_qa import answer_question_stream
from model_loader import load_processor
from pdf_io import load_pdf_pages
from themes.research_monochrome import theme
from ui_state import create_initial_state, hash_bytes, page_cache, parse_cache

# Pre-load the processor at startup (CPU-only, no GPU needed).
# This avoids paying the processor load cost on the first user request.
load_processor()

TITLE = "Granite Vision: Document Intelligence"
DESCRIPTION = (
    "Upload a PDF, Word, Excel, PowerPoint, or image to explore Granite-Vision-4.1-4B's document intelligence capabilities — "
    "including Chart2Summary, Chart2CSV, Chart2Code, Table Extraction, and Image Description — "
    "with automatic Docling-powered parsing for documents and direct inference on uploaded images."
)

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".jfif", ".png", ".bmp", ".dib", ".gif", ".tif", ".tiff", ".webp"}
OFFICE_EXTENSIONS = {".docx", ".xlsx", ".pptx"}

css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

# Patterns for wrappers the model may put around HTML table output.
_FENCE_OPEN_RE = re.compile(r"^```(?:html)?\s*")
_FENCE_CLOSE_RE = re.compile(r"\s*```$")
_BRACKET_OPEN_RE = re.compile(r"^\[\s*")
_BRACKET_CLOSE_RE = re.compile(r"\s*\]$")


def _is_image_file(file_path: str) -> bool:
    """Check whether a file path points to a supported image format."""
    ext = os.path.splitext(file_path)[1].lower()
    return ext in IMAGE_EXTENSIONS


def _is_office_file(file_path: str) -> bool:
    """Check whether a file path points to a supported Office format (DOCX/XLSX/PPTX)."""
    ext = os.path.splitext(file_path)[1].lower()
    return ext in OFFICE_EXTENSIONS


def _process_image_upload(file_path: str, session_state: dict[str, Any]) -> tuple:
    """Load a directly uploaded image and register it as the only figure."""
    # convert() returns an independent in-memory copy, so the source file
    # handle can be closed as soon as the conversion is done.
    with Image.open(file_path) as source_image:
        image = source_image.convert("RGB")

    figures_info = [{"image": image, "page": 0, "bbox": None, "caption": ""}]
    session_state["page_images"] = [image]
    session_state["parsed_result"] = {}
    session_state["figures_info"] = figures_info
    session_state["selected_figure"] = figures_info[0]
    return (
        "Image loaded successfully.\nNumber of figures: 1.",
        "Image uploaded directly (no document parsing needed)",
        "Figure 1 of 1 (Page 1)",
        "",
        image,
        session_state,
    )


def _process_document_upload(
    file_path: str,
    file_bytes: bytes,
    file_hash: str,
    max_pages: int,
    session_state: dict[str, Any],
) -> tuple:
    """Render pages (non-Office files only), parse the document, and extract figures."""
    file_ext = os.path.splitext(file_path)[1].lower()
    fmt_label = file_ext.lstrip(".").upper()
    status_lines = [f"{fmt_label} loaded successfully."]

    if _is_office_file(file_path):
        # Office documents are not rendered to page images.
        page_images = []
        session_state["page_images"] = []
    else:
        cache_key = f"{file_hash}_{max_pages}"
        if cache_key in page_cache:
            page_images = page_cache[cache_key]
        else:
            page_images = load_pdf_pages(file_bytes, max_pages=max_pages)
            page_cache[cache_key] = page_images
        session_state["page_images"] = page_images
        status_lines.append(f"Number of pages rendered: {len(page_images)} (max {max_pages}).")

    if file_hash in parse_cache:
        parse_result = parse_cache[file_hash]
    else:
        parse_result = parse_document(file_bytes, file_ext=file_ext)
        parse_cache[file_hash] = parse_result
    session_state["parsed_result"] = parse_result
    status_lines.append("Document parsing done using Docling.")

    figures_info = extract_figures(page_images, parse_result.get("figures", []))
    session_state["figures_info"] = figures_info
    status_lines.append(f"Number of figures extracted: {len(figures_info)}.")

    session_state["selected_figure"] = figures_info[0] if figures_info else None
    # current_figure_index was reset to 0 by process_upload, so this reports
    # the first figure (or the "No figures found" placeholder).
    fig_status, fig_caption, fig_image = _get_figure_display(session_state)

    html_content = parse_result.get("html", "No content available")
    status = "\n".join(status_lines)
    return status, html_content, fig_status, fig_caption, fig_image, session_state


def process_upload(file_path: str, session_state: dict[str, Any]) -> tuple:
    """Parse an uploaded PDF or Office document, or load an image, and extract figures.

    Args:
        file_path: Path to the uploaded file.
        session_state: Current Gradio session state dictionary.

    Returns:
        Tuple of (status, html_content, fig_status, fig_caption, fig_image, session_state).
    """
    max_pages = 20
    session_state["current_figure_index"] = 0
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    # Forget figures from any previous upload so that a failed or empty
    # upload cannot leave other tabs operating on a stale figure.
    session_state["figures_info"] = []
    session_state["selected_figure"] = None

    if not file_path:
        return "Please upload a PDF, Office document, or image.", "No document loaded", "No figures", "", None, session_state

    try:
        with open(file_path, "rb") as f:
            file_bytes = f.read()

        file_hash = hash_bytes(file_bytes)
        session_state["uploaded_file_hash"] = file_hash
        session_state["uploaded_file_bytes"] = file_bytes

        if _is_image_file(file_path):
            return _process_image_upload(file_path, session_state)
        return _process_document_upload(file_path, file_bytes, file_hash, max_pages, session_state)
    except Exception as e:  # noqa: BLE001
        # UI boundary: report the failure to the user instead of crashing the app.
        print(f"Error: {e}")
        traceback.print_exc()
        return f"Error: {e!s}", f"Error loading document: {e!s}", "Error", "", None, session_state


def _get_figure_display(session_state: dict[str, Any]) -> tuple[str, str, Image.Image | None]:
    """Return the current figure's display info, caption, and image.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image).
    """
    figures_info = session_state.get("figures_info", [])
    idx = session_state.get("current_figure_index", 0)
    if not figures_info:
        return "No figures found", "", None
    fig = figures_info[idx]
    fig_status = f"Figure {idx + 1} of {len(figures_info)} (Page {fig['page'] + 1})"
    fig_caption = fig.get("caption", "No caption")
    return fig_status, fig_caption, fig["image"]


def _step_figure(session_state: dict[str, Any], step: int) -> tuple:
    """Move the figure selection by ``step`` positions, wrapping around at both ends."""
    figures_info = session_state.get("figures_info", [])
    if not figures_info:
        return "No figures found", "", None, session_state

    idx = (session_state.get("current_figure_index", 0) + step) % len(figures_info)
    session_state["current_figure_index"] = idx
    session_state["selected_figure"] = figures_info[idx]
    # A new figure starts with a clean conversation.
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None

    fig_status, fig_caption, fig_image = _get_figure_display(session_state)
    return fig_status, fig_caption, fig_image, session_state


def next_figure(session_state: dict[str, Any]) -> tuple:
    """Advance to the next figure.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    return _step_figure(session_state, 1)


def prev_figure(session_state: dict[str, Any]) -> tuple:
    """Go back to the previous figure.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    return _step_figure(session_state, -1)


def _stream_vision_answer(session_state: dict[str, Any], prompt: str):  # noqa: ANN202
    """Stream the model's answer to ``prompt`` for the selected figure.

    Yields (text, session_state) pairs: each partial produced by
    ``answer_question_stream``, or a single "No figure selected" / "Error: ..."
    message.
    """
    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        yield "No figure selected", session_state
        return

    try:
        image = selected_fig["image"]
        for partial in answer_question_stream(image, prompt, [], None):
            yield partial, session_state
    except Exception as e:  # noqa: BLE001
        yield f"Error: {e!s}", session_state


def describe_image_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Generate a detailed description of the selected figure (streaming)."""
    yield from _stream_vision_answer(session_state, "Describe this image in detail")


def load_current_figure(session_state: dict[str, Any]) -> tuple:
    """Load the current figure into display components (called on tab select).

    Also clears conversation history so each tab starts fresh.
    """
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    fig_status, fig_caption, fig_image = _get_figure_display(session_state)
    return fig_status, fig_caption, fig_image, session_state


PROMPT_TEXT_CODE = (
    "Please take a look at this chart image and generate Python code that perfectly reconstructs this chart image."
)
# NOTE(review): both prompts below are empty strings, so the summary and table
# helpers send an empty prompt to the model. Presumably the intended values
# were lost somewhere upstream of this file — confirm against the model's
# documented task prompts.
PROMPT_TEXT_SUMMARY = ""
PROMPT_TEXT_TABLE = ""


def extract_code_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Generate Python code to reconstruct the selected chart (streaming)."""
    yield from _stream_vision_answer(session_state, PROMPT_TEXT_CODE)


def extract_summary_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Generate a text summary of the selected chart (streaming)."""
    yield from _stream_vision_answer(session_state, PROMPT_TEXT_SUMMARY)


def _strip_html_wrappers(text: str) -> str:
    """Strip a surrounding markdown code fence and square brackets from ``text``."""
    cleaned = _FENCE_OPEN_RE.sub("", text.strip())
    cleaned = _FENCE_CLOSE_RE.sub("", cleaned.strip())
    cleaned = _BRACKET_OPEN_RE.sub("", cleaned.strip())
    return _BRACKET_CLOSE_RE.sub("", cleaned.strip())


def extract_table_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Extract tables as HTML from the selected figure (streaming)."""
    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        yield "No figure selected", session_state
        return

    try:
        image = selected_fig["image"]
        latest = ""
        for partial in answer_question_stream(image, PROMPT_TEXT_TABLE, [], None):
            latest = partial
            yield latest, session_state

        # Final pass: strip markdown fences / brackets the model may wrap around HTML
        cleaned = _strip_html_wrappers(latest)
        if cleaned != latest:
            yield cleaned, session_state
    except Exception as e:  # noqa: BLE001
        yield f"Error: {e!s}", session_state


def extract_csv_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Extract CSV data from the selected chart (streaming)."""
    selected_fig = session_state.get("selected_figure")
    if selected_fig is None:
        yield "No figure selected", session_state
        return

    try:
        image = selected_fig["image"]
        for partial in extract_csv_stream(image):
            yield partial, session_state
    except Exception as e:  # noqa: BLE001
        yield f"Error: {e!s}", session_state


def _make_nav(nav_fn: Any) -> Any:
    """Wrap a nav function to also clear the result panel when navigating figures."""

    def _wrapper(session_state: dict[str, Any]) -> tuple:
        fig_status, fig_caption, fig_image, state = nav_fn(session_state)
        # The extra "" is routed to the tab's result component.
        return fig_status, fig_caption, fig_image, "", state

    return _wrapper


with gr.Blocks(
    title=TITLE,
    theme=theme,
    css_paths=css_file_path,
    head_paths=head_file_path,
    fill_height=True,
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    session_state = gr.State(create_initial_state())

    # Per-tab nav wrappers: clear result output when switching figures
    _sum_prev = _make_nav(prev_figure)
    _sum_next = _make_nav(next_figure)
    _csv_prev = _make_nav(prev_figure)
    _csv_next = _make_nav(next_figure)
    _code_prev = _make_nav(prev_figure)
    _code_next = _make_nav(next_figure)
    _tbl_prev = _make_nav(prev_figure)
    _tbl_next = _make_nav(next_figure)
    _qa_prev = _make_nav(prev_figure)
    _qa_next = _make_nav(next_figure)

    with gr.Tabs():
        # TAB 1: UPLOAD & PARSE
        with gr.Tab("Parse & Extract"):
            file_path = gr.File(
                label="Upload PDF, Office Document, or Image",
                file_types=[".pdf", ".docx", ".xlsx", ".pptx", ".jpg", ".jpeg", ".jfif", ".png", ".bmp", ".dib", ".gif", ".tif", ".tiff", ".webp"],
            )
            status = gr.Textbox(label="Status", interactive=False, lines=2)

            with gr.Row():
                with gr.Column(scale=1):
                    html_view = gr.Textbox(
                        label="Parsed Document (Docling)",
                        value="Upload a document to see parsed content",
                        lines=35,
                        interactive=False,
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Extracted Figures")
                    fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    fig_caption = gr.Textbox(label="Caption", interactive=False)
                    fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        prev_btn = gr.Button("Previous", scale=1)
                        next_btn = gr.Button("Next", scale=1)

            file_path.upload(
                process_upload,
                inputs=[file_path, session_state],
                outputs=[status, html_view, fig_info, fig_caption, fig_image, session_state],
            )
            next_btn.click(
                next_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state],
            )
            prev_btn.click(
                prev_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state],
            )

        # TAB 2: CHART2SUMMARY
        with gr.Tab("Chart2Summary") as summary_tab:
            gr.Markdown("Generate a text summary of the selected chart")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    summary_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    summary_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    summary_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        summary_prev_btn = gr.Button("Previous", scale=1)
                        summary_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Summary")
                    summary_btn = gr.Button("Generate Summary", variant="primary")
                    summary_out = gr.Textbox(label="Chart Summary", lines=20, interactive=False)

            summary_prev_btn.click(
                _sum_prev,
                inputs=[session_state],
                outputs=[summary_fig_info, summary_fig_caption, summary_fig_image, summary_out, session_state],
            )
            summary_next_btn.click(
                _sum_next,
                inputs=[session_state],
                outputs=[summary_fig_info, summary_fig_caption, summary_fig_image, summary_out, session_state],
            )
            summary_btn.click(
                extract_summary_helper,
                inputs=[session_state],
                outputs=[summary_out, session_state],
            )
            summary_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[summary_fig_info, summary_fig_caption, summary_fig_image, session_state],
            )

        # TAB 3: CHART2CSV
        with gr.Tab("Chart2CSV") as csv_tab:
            gr.Markdown("Extract CSV data from the selected chart")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    csv_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    csv_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    csv_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        csv_prev_btn = gr.Button("Previous", scale=1)
                        csv_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### CSV Extraction")
                    extract_btn = gr.Button("Extract CSV", variant="primary")
                    csv_out = gr.Textbox(label="CSV", lines=20, interactive=False)

            csv_prev_btn.click(
                _csv_prev,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, csv_out, session_state],
            )
            csv_next_btn.click(
                _csv_next,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, csv_out, session_state],
            )
            extract_btn.click(
                extract_csv_helper,
                inputs=[session_state],
                outputs=[csv_out, session_state],
            )
            csv_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, session_state],
            )

        # TAB 4: CHART2CODE
        with gr.Tab("Chart2Code") as code_tab:
            gr.Markdown("Generate Python code to reconstruct the selected chart")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    code_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    code_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    code_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        code_prev_btn = gr.Button("Previous", scale=1)
                        code_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Generated Code")
                    code_btn = gr.Button("Generate Code", variant="primary")
                    code_out = gr.Textbox(label="Python Code", lines=20, interactive=False)

            code_prev_btn.click(
                _code_prev,
                inputs=[session_state],
                outputs=[code_fig_info, code_fig_caption, code_fig_image, code_out, session_state],
            )
            code_next_btn.click(
                _code_next,
                inputs=[session_state],
                outputs=[code_fig_info, code_fig_caption, code_fig_image, code_out, session_state],
            )
            code_btn.click(
                extract_code_helper,
                inputs=[session_state],
                outputs=[code_out, session_state],
            )
            code_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[code_fig_info, code_fig_caption, code_fig_image, session_state],
            )

        # TAB 5: TABLE EXTRACTION
        with gr.Tab("Table Extraction") as table_tab:
            gr.Markdown("Extract table data as HTML from the selected figure")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    table_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    table_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    table_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        table_prev_btn = gr.Button("Previous", scale=1)
                        table_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Table Extraction")
                    table_btn = gr.Button("Extract Table", variant="primary")
                    # NOTE(review): the placeholder is kept on one line so the
                    # literal is valid Python; the <p> wrapper is an assumption —
                    # confirm the intended markup/styling.
                    table_out = gr.HTML(
                        value="<p>Upload a document and click Extract Table to see results here</p>",
                    )

            table_prev_btn.click(
                _tbl_prev,
                inputs=[session_state],
                outputs=[table_fig_info, table_fig_caption, table_fig_image, table_out, session_state],
            )
            table_next_btn.click(
                _tbl_next,
                inputs=[session_state],
                outputs=[table_fig_info, table_fig_caption, table_fig_image, table_out, session_state],
            )
            table_btn.click(
                extract_table_helper,
                inputs=[session_state],
                outputs=[table_out, session_state],
            )
            table_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[table_fig_info, table_fig_caption, table_fig_image, session_state],
            )

        # TAB 6: IMAGE DESCRIPTION
        with gr.Tab("Image Description") as qa_tab:
            gr.Markdown("Get a detailed description of the selected figure")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    qa_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    qa_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    qa_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        qa_prev_btn = gr.Button("Previous", scale=1)
                        qa_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Description")
                    describe_btn = gr.Button("Describe Image", variant="primary")
                    answer = gr.Textbox(label="Description", lines=20, interactive=False)

            qa_prev_btn.click(
                _qa_prev,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, answer, session_state],
            )
            qa_next_btn.click(
                _qa_next,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, answer, session_state],
            )
            describe_btn.click(
                describe_image_helper,
                inputs=[session_state],
                outputs=[answer, session_state],
            )
            qa_tab.select(
                load_current_figure,
                inputs=[session_state],
                outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, session_state],
            )


if __name__ == "__main__":
    demo.launch(ssr_mode=False)