# document-ocr / app.py
# Author: Luis J Camargo
# fix: decrease layout threshold from 0.75 to 0.1 (commit 7423198)
import os
import io
import json
import base64
import re
import logging
import sys
import yaml
import traceback
import subprocess
from typing import Dict, List, Tuple, Any, Optional
import time
import gradio as gr
from PIL import Image, ImageDraw
import requests
from urllib.parse import urlparse
from huggingface_hub import snapshot_download
# --- Configuration ---
# Log to stdout so messages show up in the hosting container's log stream.
LOGGING_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT, handlers=[logging.StreamHandler(sys.stdout)])
logger = logging.getLogger("TachiwinDocOCR")
# Hugging Face Hub repository holding the fine-tuned OCR model weights.
REPO_ID = "tachiwin/Tachiwin-OCR-1.5"
# Per-run artifacts (JSON / markdown / rendered images) are written under here.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Delimiters passed to gr.Markdown so LaTeX spans in the OCR output render.
LATEX_DELIMS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
]
# --- Paddle OCR-VL Imports ---
# Try the public `paddleocr` entry point first, then fall back to the
# internal `paddlex` module path; either sets PADDLE_AVAILABLE on success.
PADDLE_AVAILABLE = False
try:
    from paddleocr import PaddleOCRVL
    PADDLE_AVAILABLE = True
    logger.info("βœ… PaddleOCRVL imported successfully from paddleocr.")
except ImportError:
    try:
        from paddlex.inference.pipelines.paddle_ocr_vl import PaddleOCRVL
        PADDLE_AVAILABLE = True
        logger.info("βœ… PaddleOCRVL imported successfully from paddlex.")
    except ImportError as e:
        logger.error(f"❌ Failed to import PaddleOCRVL: {e}")
# --- Model Initialization ---
# Global pipeline instance; populated by setup_pipeline() below, None on failure.
pipeline = None
def setup_pipeline():
    """Download the Tachiwin OCR model and build the global PaddleOCRVL pipeline.

    Sets the module-level ``pipeline`` on success. On any failure the error is
    logged with a full traceback and ``pipeline`` stays ``None`` so the UI can
    report an unavailable backend instead of crashing at import time.
    """
    global pipeline
    if not PADDLE_AVAILABLE:
        logger.error("Skipping pipeline setup because PaddleOCRVL is not available.")
        return
    try:
        logger.info("πŸš€ Starting Tachiwin Doc OCR Pipeline Setup...")
        # 1. Fetch (or reuse the cached) model weights from the Hugging Face Hub.
        logger.info(f"πŸ“¦ Downloading custom model from HF: {REPO_ID}...")
        local_model_path = snapshot_download(repo_id=REPO_ID)
        logger.info(f"βœ… Model downloaded to: {local_model_path}")
        # 2. Instantiate the pipeline against the downloaded weights.
        #    layout_threshold=0.1 keeps low-confidence layout regions
        #    (deliberately lowered from 0.75 per the commit note — confirm intent).
        logger.info("βš™οΈ Initializing PaddleOCRVL pipeline instance...")
        pipeline = PaddleOCRVL(
            pipeline_version="v1.5",
            vl_rec_model_name="PaddleOCR-VL-1.5-0.9B",
            vl_rec_model_dir=local_model_path,
            device="cpu",
            layout_threshold=0.1,
            enable_mkldnn=True,
            use_queues=True,
        )
        logger.info("✨ Pipeline instance created successfully!")
    except Exception:
        # format_exc() already carries the exception details, so the previously
        # bound-but-unused `as e` variable was dropped.
        logger.error("πŸ”₯ CRITICAL: Pipeline Setup Failed")
        logger.error(traceback.format_exc())
# Build the pipeline eagerly at import time so the status card in the UI
# below can report whether the backend is ready.
if PADDLE_AVAILABLE:
    setup_pipeline()
# --- Helper Functions ---
def image_to_base64_data_url(filepath: str) -> str:
    """Encode an image file as an inline base64 ``data:`` URL.

    The MIME type is guessed from the file extension, defaulting to
    ``image/jpeg``. Returns an empty string if the file cannot be read.
    """
    extension_to_mime = {
        ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png",
        ".gif": "image/gif", ".webp": "image/webp", ".bmp": "image/bmp"
    }
    try:
        extension = os.path.splitext(filepath)[1].lower()
        mime = extension_to_mime.get(extension, "image/jpeg")
        with open(filepath, "rb") as fh:
            payload = base64.b64encode(fh.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to Base64: {e}")
        return ""
    return f"data:{mime};base64,{payload}"
def _escape_inequalities_in_math(md: str) -> str:
if not md:
return ""
if "$" not in md and "\\[" not in md and "\\(" not in md:
return md
_MATH_PATTERNS = [
re.compile(r"\$$([\s\S]+?)\$$"),
re.compile(r"\$([^\$]+?)\$"),
re.compile(r"\\\[([\s\S]+?)\\\]"),
re.compile(r"\\\(([\s\S]+?)\\\)"),
]
def fix(s: str) -> str:
s = s.replace("<=", r" \le ").replace(">=", r" \ge ")
s = s.replace("≀", r" \le ").replace("β‰₯", r" \ge ")
s = s.replace("<", r" \lt ").replace(">", r" \gt ")
return s
for pat in _MATH_PATTERNS:
md = pat.sub(lambda m: m.group(0).replace(m.group(1), fix(m.group(1))), md)
return md
def draw_layout_map(img_path, json_content):
    """Render the detected layout blocks as colored rectangles over the image.

    Args:
        img_path: path of the original input image to annotate.
        json_content: JSON produced by the pipeline. The caller concatenates
            per-page JSON files with blank lines, so this may contain several
            JSON documents back to back; only the first document is used.

    Returns:
        A ``data:image/png`` base64 URL of the annotated image, or ``None``
        when there is nothing to draw or any error occurs (errors are logged).
    """
    try:
        if not json_content or json_content == "{}":
            return None
        # FIX: json.loads raises on concatenated documents ("{}\n\n{}"), which
        # silently killed the layout map; raw_decode tolerates trailing text.
        data, _ = json.JSONDecoder().raw_decode(json_content.strip())
        img = Image.open(img_path).convert("RGB")
        draw = ImageDraw.Draw(img)
        # Color per layout label; unknown labels fall back to the "text" blue.
        colors = {
            "title": "#ef4444", "header": "#10b981", "footer": "#10b981",
            "figure": "#f59e0b", "table": "#8b5cf6", "text": "#3b82f6",
            "equation": "#06b6d4"
        }
        res_list = data.get("parsing_res_list", [])
        for block in res_list:
            bbox = block.get("block_bbox")
            label = block.get("block_label", "text").lower()
            if bbox and len(bbox) == 4:
                color = colors.get(label, "#3b82f6")
                draw.rectangle(bbox, outline=color, width=4)
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
    except Exception as e:
        logger.error(f"Error drawing layout map: {e}")
        return None
# --- Inference Logic ---
def run_inference(img_path, mode, use_chart=False, use_unwarping=False, progress=gr.Progress()):
    """Run the OCR pipeline on one image and stream results to the UI.

    Generator yielding 5-tuples of
    ``(markdown preview, visuals HTML, raw markdown, raw JSON, layout-map HTML)``.
    Intermediate yields carry progress placeholders; the final yield carries
    the full results. Errors are logged and surfaced as a message tuple
    rather than raised.

    Args:
        img_path: filesystem path of the uploaded image.
        mode: UI analysis mode; anything other than "Full Document Parsing"
            disables layout detection and sets a single prompt label.
        use_chart: forwarded as ``use_chart_recognition``.
        use_unwarping: enables both document unwarping and orientation
            classification (the two are tied together here).
        progress: Gradio progress tracker (injected by Gradio).
    """
    if not PADDLE_AVAILABLE or pipeline is None:
        yield "❌ Paddle backend not available.", "", "", "{}", ""
        return
    if not img_path:
        yield "⚠️ No image provided.", "", "", "{}", ""
        return
    try:
        logger.info(f"πŸš€ Inference Start | Mode: {mode} | Image: {img_path}")
        progress(0, desc="Running core prediction...")
        # Full Document Parsing keeps layout detection on with no prompt
        # label; each other mode targets exactly one content type.
        use_layout = True
        prompt_label = None
        if mode == "Formula Recognition":
            use_layout = False
            prompt_label = "formula"
        elif mode == "Table Recognition":
            use_layout = False
            prompt_label = "table"
        elif mode == "Text Recognition":
            use_layout = False
            prompt_label = "text"
        elif mode == "Feature Spotting":
            use_layout = False
            prompt_label = "spotting"
        predict_params = {
            "input": img_path,
            "use_layout_detection": use_layout,
            "prompt_label": prompt_label,
            "use_chart_recognition": use_chart,
            "use_doc_unwarping": use_unwarping,
            "use_doc_orientation_classify": use_unwarping,
            "use_queues": False
        }
        output_iter = pipeline.predict(**predict_params)
        pages_res = []
        md_content = ""
        json_content = ""
        vis_html = ""
        # Each run writes its artifacts into its own timestamped directory.
        run_id = f"run_{int(time.time())}"
        run_output_dir = os.path.join(OUTPUT_DIR, run_id)
        os.makedirs(run_output_dir, exist_ok=True)
        for i, res in enumerate(output_iter):
            logger.info(f"Received segment/page {i+1}...")
            progress(None, desc=f"Analyzing segment {i+1}...")
            pages_res.append(res)
            # Persist per-segment artifacts; the rendered result images are
            # picked back up from disk and streamed to the UI below.
            seg_dir = os.path.join(run_output_dir, f"seg_{i}")
            os.makedirs(seg_dir, exist_ok=True)
            res.save_to_json(save_path=seg_dir)
            res.save_to_markdown(save_path=seg_dir)
            res.save_to_img(save_path=seg_dir)
            fnames = os.listdir(seg_dir)
            for fname in fnames:
                fpath = os.path.join(seg_dir, fname)
                # "res"/"vis" in the filename marks annotated result renders.
                if fname.endswith((".png", ".jpg", ".jpeg")) and ("res" in fname or "vis" in fname):
                    vis_src = image_to_base64_data_url(fpath)
                    new_vis = f'<div style="margin-bottom:20px; border: 2px solid #10b981; border-radius: 12px; overflow: hidden; background:white;">'
                    new_vis += f'<img src="{vis_src}" alt="Page {i+1}" style="width:100%;"></div>'
                    # Substring check avoids appending the same render twice.
                    if new_vis not in vis_html:
                        vis_html += new_vis
            # Intermediate yield: stream visuals while processing continues.
            yield "βŒ› Reconstructing structure...", vis_html, "βŒ› Reconstructing...", "{}", ""
        logger.info("πŸ”„ Restructuring pages for final output...")
        progress(0.9, desc="Restructuring pages...")
        reconstructed_output = pipeline.restructure_pages(
            pages_res,
            merge_tables=True,
            relevel_titles=True,
            concatenate_pages=True
        )
        for i, res in enumerate(reconstructed_output):
            res.save_to_json(save_path=run_output_dir)
            res.save_to_markdown(save_path=run_output_dir)
        # Collect the restructured markdown / JSON files into single strings,
        # skipping exact-duplicate file contents.
        fnames = os.listdir(run_output_dir)
        for fname in fnames:
            fpath = os.path.join(run_output_dir, fname)
            if fname.endswith(".md"):
                with open(fpath, 'r', encoding='utf-8') as f:
                    c = f.read()
                    if c not in md_content: md_content += c + "\n\n"
            elif fname.endswith(".json"):
                with open(fpath, 'r', encoding='utf-8') as f:
                    c = f.read()
                    if c not in json_content: json_content += c + "\n\n"
        final_md = _escape_inequalities_in_math(md_content) or "⚠️ No text recognized."
        layout_map_src = draw_layout_map(img_path, json_content)
        layout_html = ""
        if layout_map_src:
            layout_html = f'<div style="border: 2px solid #3b82f6; border-radius: 12px; overflow: hidden;"><img src="{layout_map_src}" style="width:100%;"></div>'
        progress(1.0, desc="Complete")
        yield final_md, vis_html, md_content, json_content, layout_html
        logger.info("--- βœ… Inference and Restructuring Finished ---")
    except Exception as e:
        logger.error(f"❌ Inference Error: {e}")
        logger.error(traceback.format_exc())
        yield f"❌ Error: {str(e)}", "", "", "{}", ""
# --- UI Layout ---
# Inline stylesheet injected into the Gradio app: header banner, status card,
# and rounded output panels.
custom_css = """
body, .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
.app-header {
text-align: center;
padding: 2.5rem;
background: linear-gradient(135deg, #0284c7 0%, #10b981 100%);
color: white;
border-radius: 1.5rem;
margin-bottom: 2rem;
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
}
.app-header h1 { color: white !important; font-weight: 800; font-size: 2.5rem; }
.status-card {
background-color: #0d9488 !important;
color: white !important;
padding: 1.25rem;
border-radius: 1rem;
margin-bottom: 2.5rem;
font-weight: 600;
text-align: center;
box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06);
}
.status-card p { margin: 0; color: white !important; }
.output-box { border: 1px solid #e2e8f0 !important; border-radius: 1rem !important; }
"""
with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css) as demo:
    # Header banner.
    gr.HTML(
        """
        <div class="app-header">
        <h1>🌎 Tachiwin Document Parsing OCR 🦑</h1>
        <p>Advancing linguistic rights with state-of-the-art document parsing</p>
        </div>
        """
    )
    # Status card: reflects whether setup_pipeline() succeeded at import time.
    status_msg = "Pipeline Ready" if pipeline else "Initialization in progress..."
    gr.HTML(
        f"""
        <div class="status-card">
        <p><strong>⚑ Status:</strong> {status_msg} &nbsp;|&nbsp;
        <strong>Model:</strong> <code>{REPO_ID}</code> &nbsp;|&nbsp;
        <strong>Hardware:</strong> CPU</p>
        </div>
        """
    )
    with gr.Row():
        # Left column: image input, mode selector, advanced options, run button.
        with gr.Column(scale=5):
            img_input = gr.Image(label="Upload Image", type="filepath")
            mode_selector = gr.Dropdown(
                choices=["Full Document Parsing", "Formula Recognition", "Table Recognition", "Text Recognition", "Feature Spotting"],
                value="Full Document Parsing",
                label="Analysis Mode"
            )
            with gr.Accordion("Advanced Options", open=False):
                chart_check = gr.Checkbox(label="Chart Recognition", value=True)
                unwarp_check = gr.Checkbox(label="Document Unwarping / Orientation Fix", value=False)
            btn_run = gr.Button("πŸš€ Start Analysis", variant="primary")
        # Right column: tabbed result views fed by run_inference's 5-tuple yields.
        with gr.Column(scale=7):
            with gr.Tabs():
                with gr.Tab("πŸ“ Markdown View"):
                    md_view = gr.Markdown(latex_delimiters=LATEX_DELIMS, elem_classes="output-box")
                with gr.Tab("πŸ–ΌοΈ Visual Results"):
                    vis_view = gr.HTML('<div style="text-align:center; color:#94a3b8; padding: 50px;">Results will appear here.</div>')
                with gr.Tab("πŸ“œ Raw Source"):
                    raw_view = gr.Code(language="markdown")
                with gr.Tab("πŸ’Ύ JSON Feed"):
                    json_view = gr.Code(language="json")
                with gr.Tab("πŸ—ΊοΈ Layout Map"):
                    layout_view = gr.HTML('<div style="text-align:center; color:#94a3b8; padding: 50px;">Bboxes view will appear here.</div>')
    def unified_wrapper(img, mode, chart, unwarp, progress=gr.Progress()):
        """Validate input, show placeholders, then relay run_inference's stream."""
        if not img:
            yield "⚠️ Please upload an image.", "", "", "{}", ""
            return
        # Immediate placeholder yield so all five output panes show activity.
        yield "βŒ› Running prediction...", gr.update(value="<div style='text-align:center;'>βŒ› Processing...</div>"), "βŒ› Processing...", "{}", gr.update(value="<div style='text-align:center;'>βŒ› Generating map...</div>")
        for res_preview, res_vis, res_raw, res_json, res_map in run_inference(img, mode, chart, unwarp, progress=progress):
            yield res_preview, res_vis, res_raw, res_json, res_map
    btn_run.click(
        unified_wrapper,
        [img_input, mode_selector, chart_check, unwarp_check],
        [md_view, vis_view, raw_view, json_view, layout_view],
        show_progress="full"
    )
    gr.Markdown("--- \n *Tachiwin Project: Indigenous Languages of Mexico.*")
if __name__ == "__main__":
    # queue() enables request queuing for the generator-based (streaming) handlers.
    demo.queue().launch()