Deploybot
Deploy from stable branch
49574d5
"""Granite Vision Document Intelligence Demo.
Upload a PDF, Office document (DOCX/XLSX/PPTX), or image to explore
Granite-Vision-4.1-4B capabilities including Chart2CSV, Chart2Code,
Chart2Summary, Table Extraction, and Image Description.
"""
# Compatibility shim for gradio_client's JSON Schema helpers.
# gradio 5.x emits boolean schema values (additionalProperties: false/true,
# which is valid JSON Schema), but gradio_client 1.5.x does not guard against
# bool in get_type(), so every request to the /info endpoint raises TypeError.
try:
    import gradio_client.utils as _gcu

    # Keep handles to the unpatched implementations so the wrappers can delegate.
    _orig_get_type = _gcu.get_type
    _orig_j2p = _gcu._json_schema_to_python_type

    def _patched_get_type(schema):  # noqa: ANN001, ANN202
        # Dict schemas go to the original; anything else is reported as "unknown".
        if isinstance(schema, dict):
            return _orig_get_type(schema)
        return "unknown"

    def _patched_j2p(schema, defs=None):  # noqa: ANN001, ANN202
        # Dict schemas go to the original; a truthy non-dict (e.g. True) maps
        # to "any" and a falsy one (e.g. False) maps to "unknown".
        if isinstance(schema, dict):
            return _orig_j2p(schema, defs)
        if schema:
            return "any"
        return "unknown"

    _gcu.get_type = _patched_get_type
    _gcu._json_schema_to_python_type = _patched_j2p
except Exception:  # noqa: BLE001
    # Best effort: if the import or attribute lookups fail, run unpatched.
    pass
import os
from pathlib import Path
from typing import Any
import gradio as gr
from PIL import Image
from crops import extract_figures
from document_parser import parse_document
from infer_chart2csv import extract_csv_stream
from infer_vision_qa import answer_question_stream
from model_loader import load_processor
from pdf_io import load_pdf_pages
from themes.research_monochrome import theme
from ui_state import create_initial_state, hash_bytes, page_cache, parse_cache
# Pre-load the processor at startup (CPU-only, no GPU needed).
# This avoids paying the processor load cost on the first user request.
# NOTE: this is a module-level call, so it runs whenever this module is imported.
load_processor()
# Page heading and the introductory blurb rendered beneath it.
TITLE = "Granite Vision: Document Intelligence"
DESCRIPTION = (
    "Upload a PDF, Word, Excel, PowerPoint, or image to explore "
    "Granite-Vision-4.1-4B's document intelligence capabilities — "
    "including Chart2Summary, Chart2CSV, Chart2Code, Table Extraction, "
    "and Image Description — "
    "with automatic Docling-powered parsing for documents "
    "and direct inference on uploaded images."
)

# Lower-case file suffixes (with leading dot) used to classify uploads.
IMAGE_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".jfif",
    ".png",
    ".bmp",
    ".dib",
    ".gif",
    ".tif",
    ".tiff",
    ".webp",
}
OFFICE_EXTENSIONS = {".docx", ".xlsx", ".pptx"}

# Static assets that live next to this module.
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"
def _is_image_file(file_path: str) -> bool:
    """Return True when the path's extension (case-insensitive) is a known image type."""
    _, suffix = os.path.splitext(file_path)
    return suffix.lower() in IMAGE_EXTENSIONS
def _is_office_file(file_path: str) -> bool:
    """Return True when the path's extension (case-insensitive) is DOCX, XLSX, or PPTX."""
    _, suffix = os.path.splitext(file_path)
    return suffix.lower() in OFFICE_EXTENSIONS
def process_upload(file_path: str, session_state: dict[str, Any]) -> tuple:
    """Load an uploaded image, or parse an uploaded document, and extract figures.

    Image files are loaded directly and exposed as a single figure. Office
    files (DOCX/XLSX/PPTX) are parsed without rendering any pages. Every other
    file is handled as a PDF: up to ``max_pages`` pages are rendered first and
    the document is then parsed. Rendered pages and parse results are reused
    from ``page_cache`` / ``parse_cache``, keyed on the value returned by
    ``hash_bytes`` for the file content.

    Any exception raised while loading is caught, printed with its traceback,
    and reported through the returned status strings instead of propagating.

    Args:
        file_path: Path to the uploaded file.
        session_state: Current Gradio session state dictionary (mutated in place).

    Returns:
        Tuple of (status, html_content, fig_status, fig_caption, fig_image, session_state).
    """
    max_pages = 20

    # Every upload attempt starts from a clean navigation/conversation state.
    session_state["current_figure_index"] = 0
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None

    if not file_path:
        return "Please upload a PDF, Office document, or image.", "No document loaded", "No figures", "", None, session_state

    try:
        with open(file_path, "rb") as f:
            file_bytes = f.read()
        file_hash = hash_bytes(file_bytes)
        session_state["uploaded_file_hash"] = file_hash
        session_state["uploaded_file_bytes"] = file_bytes

        if _is_image_file(file_path):
            # convert() returns a new image, so the source image (and the file
            # handle it may still hold, e.g. for multi-frame GIF/TIFF files)
            # can be closed as soon as the conversion is done.
            with Image.open(file_path) as source_image:
                image = source_image.convert("RGB")
            figures_info = [{"image": image, "page": 0, "bbox": None, "caption": ""}]
            session_state["page_images"] = [image]
            session_state["parsed_result"] = {}
            session_state["figures_info"] = figures_info
            session_state["selected_figure"] = figures_info[0]
            return (
                "Image loaded successfully.\nNumber of figures: 1.",
                "Image uploaded directly (no document parsing needed)",
                "Figure 1 of 1 (Page 1)",
                "",
                image,
                session_state,
            )

        file_ext = os.path.splitext(file_path)[1].lower()
        is_office = _is_office_file(file_path)
        fmt_label = file_ext.lstrip(".").upper()
        status_lines = [f"{fmt_label} loaded successfully."]

        if is_office:
            # Office documents are not rendered to page images.
            page_images = []
            session_state["page_images"] = []
        else:
            # The page cache key includes the page limit used for rendering.
            cache_key = f"{file_hash}_{max_pages}"
            if cache_key in page_cache:
                page_images = page_cache[cache_key]
            else:
                page_images = load_pdf_pages(file_bytes, max_pages=max_pages)
                page_cache[cache_key] = page_images
            session_state["page_images"] = page_images
            status_lines.append(f"Number of pages rendered: {len(page_images)} (max {max_pages}).")

        if file_hash in parse_cache:
            parse_result = parse_cache[file_hash]
        else:
            parse_result = parse_document(file_bytes, file_ext=file_ext)
            parse_cache[file_hash] = parse_result
        session_state["parsed_result"] = parse_result
        status_lines.append("Document parsing done using Docling.")

        figures_info = extract_figures(page_images, parse_result.get("figures", []))
        session_state["figures_info"] = figures_info
        status_lines.append(f"Number of figures extracted: {len(figures_info)}.")

        if figures_info:
            # Show the first extracted figure by default.
            session_state["selected_figure"] = figures_info[0]
            fig_status = f"Figure 1 of {len(figures_info)} (Page {figures_info[0]['page'] + 1})"
            fig_caption = figures_info[0].get("caption", "No caption")
            fig_image = figures_info[0]["image"]
        else:
            session_state["selected_figure"] = None
            fig_status = "No figures found"
            fig_caption = ""
            fig_image = None

        html_content = parse_result.get("html", "No content available")
        status = "\n".join(status_lines)
        return status, html_content, fig_status, fig_caption, fig_image, session_state
    except Exception as e:  # noqa: BLE001
        # UI boundary: report the failure to the user rather than crash the callback.
        import traceback

        print(f"Error: {e}")
        traceback.print_exc()
        return f"Error: {e!s}", f"Error loading document: {e!s}", "Error", "", None, session_state
def _get_figure_display(session_state: dict[str, Any]) -> tuple[str, str, Image.Image | None]:
    """Build the display triple for the figure at the session's current index.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image). When the session holds
        no figures this is ("No figures found", "", None).
    """
    figures = session_state.get("figures_info", [])
    if not figures:
        return "No figures found", "", None

    position = session_state.get("current_figure_index", 0)
    current = figures[position]
    # Figure and page numbers are shown 1-based.
    label = f"Figure {position + 1} of {len(figures)} (Page {current['page'] + 1})"
    caption = current.get("caption", "No caption")
    return label, caption, current["image"]
def next_figure(session_state: dict[str, Any]) -> tuple:
    """Select the figure after the current one, wrapping around at the end.

    Selecting a new figure also resets the conversation history and the
    current image path stored in the session.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    figures = session_state.get("figures_info", [])
    if figures:
        current = session_state.get("current_figure_index", 0)
        # Modulo makes the step from the last figure land on the first.
        target = (current + 1) % len(figures)
        session_state["current_figure_index"] = target
        session_state["selected_figure"] = figures[target]
        session_state["conversation_history"] = []
        session_state["current_image_path"] = None
        label, caption, picture = _get_figure_display(session_state)
        return label, caption, picture, session_state
    return "No figures found", "", None, session_state
def prev_figure(session_state: dict[str, Any]) -> tuple:
    """Select the figure before the current one, wrapping around at the start.

    Selecting a new figure also resets the conversation history and the
    current image path stored in the session.

    Args:
        session_state: Current session state dictionary.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    figures = session_state.get("figures_info", [])
    if figures:
        current = session_state.get("current_figure_index", 0)
        # Modulo makes the step back from the first figure land on the last.
        target = (current - 1) % len(figures)
        session_state["current_figure_index"] = target
        session_state["selected_figure"] = figures[target]
        session_state["conversation_history"] = []
        session_state["current_image_path"] = None
        label, caption, picture = _get_figure_display(session_state)
        return label, caption, picture, session_state
    return "No figures found", "", None, session_state
def describe_image_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream a detailed description of the currently selected figure.

    Yields (text, session_state) pairs: each partial answer as it arrives,
    or a single message when no figure is selected or an error occurs.
    """
    figure = session_state.get("selected_figure")
    if figure is not None:
        try:
            picture = figure["image"]
            for chunk in answer_question_stream(picture, "Describe this image in detail", [], None):
                yield chunk, session_state
        except Exception as e:  # noqa: BLE001
            yield f"Error: {e!s}", session_state
    else:
        yield "No figure selected", session_state
def load_current_figure(session_state: dict[str, Any]) -> tuple:
    """Return the current figure's display values (used as a tab-select callback).

    The conversation history and current image path are reset first, so each
    tab starts fresh.

    Returns:
        Tuple of (fig_status, fig_caption, fig_image, session_state).
    """
    session_state["conversation_history"] = []
    session_state["current_image_path"] = None
    label, caption, picture = _get_figure_display(session_state)
    return label, caption, picture, session_state
# Prompts sent to the model for the chart-to-code, chart-summary, and
# table-extraction helpers respectively.
PROMPT_TEXT_CODE = (
    "Please take a look at this chart image and generate Python code "
    "that perfectly reconstructs this chart image."
)
PROMPT_TEXT_SUMMARY = "<chart2summary>"
PROMPT_TEXT_TABLE = "<tables_html>"
def extract_code_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream Python code intended to reconstruct the selected chart.

    Yields (text, session_state) pairs: each partial answer as it arrives,
    or a single message when no figure is selected or an error occurs.
    """
    figure = session_state.get("selected_figure")
    if figure is not None:
        try:
            picture = figure["image"]
            for chunk in answer_question_stream(picture, PROMPT_TEXT_CODE, [], None):
                yield chunk, session_state
        except Exception as e:  # noqa: BLE001
            yield f"Error: {e!s}", session_state
    else:
        yield "No figure selected", session_state
def extract_summary_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream a text summary of the selected chart.

    Yields (text, session_state) pairs: each partial answer as it arrives,
    or a single message when no figure is selected or an error occurs.
    """
    figure = session_state.get("selected_figure")
    if figure is not None:
        try:
            picture = figure["image"]
            for chunk in answer_question_stream(picture, PROMPT_TEXT_SUMMARY, [], None):
                yield chunk, session_state
        except Exception as e:  # noqa: BLE001
            yield f"Error: {e!s}", session_state
    else:
        yield "No figure selected", session_state
def extract_table_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream the selected figure's tables as HTML, then emit a cleaned final value.

    Yields (text, session_state) pairs: each partial answer as it arrives,
    followed by one extra pair when stripping wrappers changed the last answer.
    A single message is yielded when no figure is selected or an error occurs.
    """
    import re

    figure = session_state.get("selected_figure")
    if figure is None:
        yield "No figure selected", session_state
        return

    try:
        picture = figure["image"]
        latest = ""
        for latest in answer_question_stream(picture, PROMPT_TEXT_TABLE, [], None):
            yield latest, session_state

        # Final pass: strip markdown fences / brackets the model may wrap around HTML.
        # The patterns are applied in order, each to the stripped result of the previous.
        wrapper_patterns = (
            r"^```(?:html)?\s*",
            r"\s*```$",
            r"^\[\s*",
            r"\s*\]$",
        )
        cleaned = latest
        for pattern in wrapper_patterns:
            cleaned = re.sub(pattern, "", cleaned.strip())

        if cleaned != latest:
            yield cleaned, session_state
    except Exception as e:  # noqa: BLE001
        yield f"Error: {e!s}", session_state
def extract_csv_helper(session_state: dict[str, Any]):  # noqa: ANN201
    """Stream CSV data extracted from the selected chart.

    Yields (text, session_state) pairs: each partial result as it arrives,
    or a single message when no figure is selected or an error occurs.
    """
    figure = session_state.get("selected_figure")
    if figure is not None:
        try:
            picture = figure["image"]
            for chunk in extract_csv_stream(picture):
                yield chunk, session_state
        except Exception as e:  # noqa: BLE001
            yield f"Error: {e!s}", session_state
    else:
        yield "No figure selected", session_state
def _make_nav(nav_fn: Any) -> Any:
"""Wrap a nav function to also clear the result panel when navigating figures."""
def _wrapper(session_state: dict[str, Any]) -> tuple:
fig_status, fig_caption, fig_image, state = nav_fn(session_state)
return fig_status, fig_caption, fig_image, "", state
return _wrapper
# Build the Gradio UI: one upload/parse tab followed by five task tabs that
# all operate on the figure currently selected in the session state.
with gr.Blocks(
    title=TITLE,
    theme=theme,
    css_paths=css_file_path,
    head_paths=head_file_path,
    fill_height=True,
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    # Session state object passed into and out of every callback below.
    session_state = gr.State(create_initial_state())
    # Per-tab nav wrappers: clear result output when switching figures
    _sum_prev = _make_nav(prev_figure)
    _sum_next = _make_nav(next_figure)
    _csv_prev = _make_nav(prev_figure)
    _csv_next = _make_nav(next_figure)
    _code_prev = _make_nav(prev_figure)
    _code_next = _make_nav(next_figure)
    _tbl_prev = _make_nav(prev_figure)
    _tbl_next = _make_nav(next_figure)
    _qa_prev = _make_nav(prev_figure)
    _qa_next = _make_nav(next_figure)
    with gr.Tabs():
        # TAB 1: UPLOAD & PARSE
        with gr.Tab("Parse & Extract"):
            file_path = gr.File(
                label="Upload PDF, Office Document, or Image",
                file_types=[".pdf", ".docx", ".xlsx", ".pptx", ".jpg", ".jpeg", ".jfif", ".png", ".bmp", ".dib", ".gif", ".tif", ".tiff", ".webp"],
            )
            status = gr.Textbox(label="Status", interactive=False, lines=2)
            with gr.Row():
                with gr.Column(scale=1):
                    html_view = gr.Textbox(
                        label="Parsed Document (Docling)",
                        value="Upload a document to see parsed content",
                        lines=35,
                        interactive=False,
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Extracted Figures")
                    fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    fig_caption = gr.Textbox(label="Caption", interactive=False)
                    fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        prev_btn = gr.Button("Previous", scale=1)
                        next_btn = gr.Button("Next", scale=1)
            # Uploading a file runs process_upload and fills the whole tab.
            file_path.upload(
                process_upload,
                inputs=[file_path, session_state],
                outputs=[status, html_view, fig_info, fig_caption, fig_image, session_state],
            )
            # This tab has no result panel, so the nav functions are used unwrapped.
            next_btn.click(
                next_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state],
            )
            prev_btn.click(
                prev_figure,
                inputs=[session_state],
                outputs=[fig_info, fig_caption, fig_image, session_state],
            )
        # TAB 2: CHART2SUMMARY
        with gr.Tab("Chart2Summary") as summary_tab:
            gr.Markdown("Generate a text summary of the selected chart")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    summary_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    summary_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    summary_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        summary_prev_btn = gr.Button("Previous", scale=1)
                        summary_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Summary")
                    summary_btn = gr.Button("Generate Summary", variant="primary")
                    summary_out = gr.Textbox(label="Chart Summary", lines=20, interactive=False)
            # Navigation also writes "" into summary_out via the _make_nav wrapper.
            summary_prev_btn.click(_sum_prev, inputs=[session_state], outputs=[summary_fig_info, summary_fig_caption, summary_fig_image, summary_out, session_state])
            summary_next_btn.click(_sum_next, inputs=[session_state], outputs=[summary_fig_info, summary_fig_caption, summary_fig_image, summary_out, session_state])
            summary_btn.click(extract_summary_helper, inputs=[session_state], outputs=[summary_out, session_state])
            # Refresh the figure panel whenever this tab is selected.
            summary_tab.select(load_current_figure, inputs=[session_state], outputs=[summary_fig_info, summary_fig_caption, summary_fig_image, session_state])
        # TAB 3: CHART2CSV
        with gr.Tab("Chart2CSV") as csv_tab:
            gr.Markdown("Extract CSV data from the selected chart")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    csv_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    csv_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    csv_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        csv_prev_btn = gr.Button("Previous", scale=1)
                        csv_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### CSV Extraction")
                    extract_btn = gr.Button("Extract CSV", variant="primary")
                    csv_out = gr.Textbox(label="CSV", lines=20, interactive=False)
            # Navigation also writes "" into csv_out via the _make_nav wrapper.
            csv_prev_btn.click(_csv_prev, inputs=[session_state], outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, csv_out, session_state])
            csv_next_btn.click(_csv_next, inputs=[session_state], outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, csv_out, session_state])
            extract_btn.click(extract_csv_helper, inputs=[session_state], outputs=[csv_out, session_state])
            # Refresh the figure panel whenever this tab is selected.
            csv_tab.select(load_current_figure, inputs=[session_state], outputs=[csv_fig_info, csv_fig_caption, csv_fig_image, session_state])
        # TAB 4: CHART2CODE
        with gr.Tab("Chart2Code") as code_tab:
            gr.Markdown("Generate Python code to reconstruct the selected chart")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    code_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    code_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    code_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        code_prev_btn = gr.Button("Previous", scale=1)
                        code_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Generated Code")
                    code_btn = gr.Button("Generate Code", variant="primary")
                    code_out = gr.Textbox(label="Python Code", lines=20, interactive=False)
            # Navigation also writes "" into code_out via the _make_nav wrapper.
            code_prev_btn.click(_code_prev, inputs=[session_state], outputs=[code_fig_info, code_fig_caption, code_fig_image, code_out, session_state])
            code_next_btn.click(_code_next, inputs=[session_state], outputs=[code_fig_info, code_fig_caption, code_fig_image, code_out, session_state])
            code_btn.click(extract_code_helper, inputs=[session_state], outputs=[code_out, session_state])
            # Refresh the figure panel whenever this tab is selected.
            code_tab.select(load_current_figure, inputs=[session_state], outputs=[code_fig_info, code_fig_caption, code_fig_image, session_state])
        # TAB 5: TABLE EXTRACTION
        with gr.Tab("Table Extraction") as table_tab:
            gr.Markdown("Extract table data as HTML from the selected figure")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    table_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    table_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    table_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        table_prev_btn = gr.Button("Previous", scale=1)
                        table_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Table Extraction")
                    table_btn = gr.Button("Extract Table", variant="primary")
                    # Unlike the other tabs, the result here is an HTML component.
                    table_out = gr.HTML(value="<p>Upload a document and click Extract Table to see results here</p>")
            # Navigation also writes "" into table_out via the _make_nav wrapper.
            table_prev_btn.click(_tbl_prev, inputs=[session_state], outputs=[table_fig_info, table_fig_caption, table_fig_image, table_out, session_state])
            table_next_btn.click(_tbl_next, inputs=[session_state], outputs=[table_fig_info, table_fig_caption, table_fig_image, table_out, session_state])
            table_btn.click(extract_table_helper, inputs=[session_state], outputs=[table_out, session_state])
            # Refresh the figure panel whenever this tab is selected.
            table_tab.select(load_current_figure, inputs=[session_state], outputs=[table_fig_info, table_fig_caption, table_fig_image, session_state])
        # TAB 6: IMAGE DESCRIPTION
        with gr.Tab("Image Description") as qa_tab:
            gr.Markdown("Get a detailed description of the selected figure")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Figure")
                    qa_fig_info = gr.Textbox(label="Figure Info", interactive=False)
                    qa_fig_caption = gr.Textbox(label="Caption", interactive=False)
                    qa_fig_image = gr.Image(label="Figure", type="pil", elem_classes=["figure-image"])
                    with gr.Row():
                        qa_prev_btn = gr.Button("Previous", scale=1)
                        qa_next_btn = gr.Button("Next", scale=1)
                with gr.Column(scale=1):
                    gr.Markdown("### Description")
                    describe_btn = gr.Button("Describe Image", variant="primary")
                    answer = gr.Textbox(label="Description", lines=20, interactive=False)
            # Navigation also writes "" into the description box via the _make_nav wrapper.
            qa_prev_btn.click(_qa_prev, inputs=[session_state], outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, answer, session_state])
            qa_next_btn.click(_qa_next, inputs=[session_state], outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, answer, session_state])
            describe_btn.click(describe_image_helper, inputs=[session_state], outputs=[answer, session_state])
            # Refresh the figure panel whenever this tab is selected.
            qa_tab.select(load_current_figure, inputs=[session_state], outputs=[qa_fig_info, qa_fig_caption, qa_fig_image, session_state])
# Launch the app only when this file is executed as a script; the launch
# call passes ssr_mode=False explicitly.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)