Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import json | |
| import base64 | |
| import re | |
| import logging | |
| import sys | |
| import yaml | |
| import traceback | |
| import subprocess | |
| from typing import Dict, List, Tuple, Any, Optional | |
| import time | |
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| import requests | |
| from urllib.parse import urlparse | |
| from huggingface_hub import snapshot_download | |
# --- Configuration ---
# Log to stdout so messages appear in the Space's container logs.
LOGGING_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT, handlers=[logging.StreamHandler(sys.stdout)])
logger = logging.getLogger("TachiwinDocOCR")
# Hugging Face Hub repo holding the fine-tuned OCR model weights.
REPO_ID = "tachiwin/Tachiwin-OCR-1.5"
# Per-run artifacts (markdown/JSON/visualisations) are written under here.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Delimiter configuration passed to gr.Markdown so inline ($, \( \)) and
# display ($$, \[ \]) math are rendered by the frontend.
LATEX_DELIMS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
]
# --- Paddle OCR-VL Imports ---
# PADDLE_AVAILABLE gates pipeline setup and inference; when both import
# attempts fail the app still starts and surfaces an error in the UI.
PADDLE_AVAILABLE = False
try:
    # Preferred: public paddleocr package.
    from paddleocr import PaddleOCRVL
    PADDLE_AVAILABLE = True
    logger.info("β PaddleOCRVL imported successfully from paddleocr.")
except ImportError:
    try:
        # Fallback: internal paddlex pipeline location.
        from paddlex.inference.pipelines.paddle_ocr_vl import PaddleOCRVL
        PADDLE_AVAILABLE = True
        logger.info("β PaddleOCRVL imported successfully from paddlex.")
    except ImportError as e:
        logger.error(f"β Failed to import PaddleOCRVL: {e}")
# --- Model Initialization ---
# Populated by setup_pipeline(); stays None when Paddle is unavailable or
# setup fails, which downstream code treats as "backend not ready".
pipeline = None


def setup_pipeline():
    """Download the OCR model from the Hub and build the PaddleOCRVL pipeline.

    Stores the pipeline instance in the module-level ``pipeline`` global.
    Best-effort: any failure is logged and the global is left as None so the
    UI can still load and show a degraded status.
    """
    global pipeline
    if not PADDLE_AVAILABLE:
        logger.error("Skipping pipeline setup because PaddleOCRVL is not available.")
        return
    try:
        logger.info("π Starting Tachiwin Doc OCR Pipeline Setup...")
        # 1. Download Model from Hugging Face Hub (cached by huggingface_hub).
        logger.info(f"π¦ Downloading custom model from HF: {REPO_ID}...")
        local_model_path = snapshot_download(repo_id=REPO_ID)
        logger.info(f"β Model downloaded to: {local_model_path}")
        # 2. Instantiate PaddleOCRVL directly as per user docs
        logger.info("βοΈ Initializing PaddleOCRVL pipeline instance...")
        pipeline = PaddleOCRVL(
            pipeline_version="v1.5",
            vl_rec_model_name="PaddleOCR-VL-1.5-0.9B",
            # Point the recognizer at the locally downloaded custom weights.
            vl_rec_model_dir=local_model_path,
            device="cpu",  # CPU-only Space hardware
            layout_threshold=0.1,
            enable_mkldnn=True,
            use_queues=True,
        )
        logger.info("β¨ Pipeline instance created successfully!")
    except Exception as e:
        # Deliberately swallow after logging: the app must still serve the UI.
        logger.error("π₯ CRITICAL: Pipeline Setup Failed")
        logger.error(traceback.format_exc())


# Build the pipeline eagerly at import time so the Space is ready on load.
if PADDLE_AVAILABLE:
    setup_pipeline()
# --- Helper Functions ---
def image_to_base64_data_url(filepath: str) -> str:
    """Encode the image file at *filepath* as a ``data:`` URL string.

    The MIME type is inferred from the file extension, defaulting to
    image/jpeg for unrecognised extensions. Returns "" on any failure.
    """
    mime_by_ext = {
        ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png",
        ".gif": "image/gif", ".webp": "image/webp", ".bmp": "image/bmp"
    }
    try:
        suffix = os.path.splitext(filepath)[1].lower()
        mime_type = mime_by_ext.get(suffix, "image/jpeg")
        with open(filepath, "rb") as fh:
            payload = base64.b64encode(fh.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to Base64: {e}")
        return ""
    return f"data:{mime_type};base64,{payload}"
def _escape_inequalities_in_math(md: str) -> str:
    """Escape <, >, <=, >=, ≤, ≥ inside LaTeX math spans of *md*.

    Raw angle brackets inside math confuse Markdown/HTML renderers (they look
    like tags), so the interiors of $...$, $$...$$, \\(...\\) and \\[...\\]
    spans are rewritten to \\lt, \\gt, \\le, \\ge. Text outside math spans is
    left untouched.

    Args:
        md: Markdown text, possibly containing math spans. Falsy input
            returns "".

    Returns:
        The markdown with inequalities inside math spans escaped.
    """
    if not md:
        return ""
    # Fast path: nothing that looks like a math delimiter.
    if "$" not in md and "\\[" not in md and "\\(" not in md:
        return md
    # $$ must be tried before $ so display math is captured as one span.
    # BUGFIX: was r"\$$([\s\S]+?)\$$" — the second '$' of each pair was
    # unescaped and acted as an end-of-string anchor, so the $$...$$ pattern
    # never matched display math.
    _MATH_PATTERNS = [
        re.compile(r"\$\$([\s\S]+?)\$\$"),
        re.compile(r"\$([^\$]+?)\$"),
        re.compile(r"\\\[([\s\S]+?)\\\]"),
        re.compile(r"\\\(([\s\S]+?)\\\)"),
    ]

    def fix(s: str) -> str:
        # Order matters: two-character operators before bare < / >.
        s = s.replace("<=", r" \le ").replace(">=", r" \ge ")
        # BUGFIX: these literals were mojibake-garbled in the source and could
        # never match the intended Unicode comparison characters.
        s = s.replace("≤", r" \le ").replace("≥", r" \ge ")
        s = s.replace("<", r" \lt ").replace(">", r" \gt ")
        return s

    for pat in _MATH_PATTERNS:
        # Replace only the span interior, keeping the delimiters intact.
        md = pat.sub(lambda m: m.group(0).replace(m.group(1), fix(m.group(1))), md)
    return md
def draw_layout_map(img_path, json_content):
    """Render detected layout bounding boxes on top of the source image.

    Parses *json_content* (a pipeline result JSON string), draws one colored
    rectangle per entry in ``parsing_res_list`` onto the image at *img_path*,
    and returns the annotated image as a PNG data URL. Returns None when
    there is nothing to draw or when rendering fails.
    """
    # Outline color per block label; unknown labels fall back to "text" blue.
    colors = {
        "title": "#ef4444", "header": "#10b981", "footer": "#10b981",
        "figure": "#f59e0b", "table": "#8b5cf6", "text": "#3b82f6",
        "equation": "#06b6d4"
    }
    try:
        if not json_content or json_content == "{}":
            return None
        parsed = json.loads(json_content)
        canvas = Image.open(img_path).convert("RGB")
        pen = ImageDraw.Draw(canvas)
        for block in parsed.get("parsing_res_list", []):
            bbox = block.get("block_bbox")
            label = block.get("block_label", "text").lower()
            # Only draw well-formed 4-element boxes.
            if bbox and len(bbox) == 4:
                pen.rectangle(bbox, outline=colors.get(label, "#3b82f6"), width=4)
        buf = io.BytesIO()
        canvas.save(buf, format="PNG")
        encoded = base64.b64encode(buf.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"
    except Exception as e:
        logger.error(f"Error drawing layout map: {e}")
        return None
# --- Inference Logic ---
def run_inference(img_path, mode, use_chart=False, use_unwarping=False, progress=gr.Progress()):
    """Run the OCR pipeline on one image and stream results to the UI.

    Generator yielding 5-tuples of
    (markdown preview, visual-results HTML, raw markdown, raw JSON, layout-map HTML)
    so Gradio can update all output panes incrementally.

    Args:
        img_path: Filesystem path of the uploaded image.
        mode: UI analysis mode; the four specialised modes disable layout
            detection and force a prompt label, anything else is treated as
            full document parsing.
        use_chart: Forwarded to the pipeline as ``use_chart_recognition``.
        use_unwarping: Enables both document unwarping and orientation
            classification (same toggle drives both flags).
        progress: Gradio progress tracker.
    """
    if not PADDLE_AVAILABLE or pipeline is None:
        yield "β Paddle backend not available.", "", "", "{}", ""
        return
    if not img_path:
        yield "β οΈ No image provided.", "", "", "{}", ""
        return
    try:
        logger.info(f"π Inference Start | Mode: {mode} | Image: {img_path}")
        progress(0, desc="Running core prediction...")
        # Specialised modes bypass layout detection and force a prompt label.
        use_layout = True
        prompt_label = None
        if mode == "Formula Recognition":
            use_layout = False
            prompt_label = "formula"
        elif mode == "Table Recognition":
            use_layout = False
            prompt_label = "table"
        elif mode == "Text Recognition":
            use_layout = False
            prompt_label = "text"
        elif mode == "Feature Spotting":
            use_layout = False
            prompt_label = "spotting"
        predict_params = {
            "input": img_path,
            "use_layout_detection": use_layout,
            "prompt_label": prompt_label,
            "use_chart_recognition": use_chart,
            "use_doc_unwarping": use_unwarping,
            # Orientation fix is deliberately tied to the unwarping toggle.
            "use_doc_orientation_classify": use_unwarping,
            "use_queues": False
        }
        output_iter = pipeline.predict(**predict_params)
        pages_res = []
        md_content = ""
        json_content = ""
        vis_html = ""
        # Each run gets its own output directory keyed by wall-clock time.
        run_id = f"run_{int(time.time())}"
        run_output_dir = os.path.join(OUTPUT_DIR, run_id)
        os.makedirs(run_output_dir, exist_ok=True)
        for i, res in enumerate(output_iter):
            logger.info(f"Received segment/page {i+1}...")
            # NOTE(review): passes None as the progress fraction — confirm
            # gr.Progress accepts None for "indeterminate".
            progress(None, desc=f"Analyzing segment {i+1}...")
            pages_res.append(res)
            # Persist per-segment artifacts (JSON, markdown, annotated images).
            seg_dir = os.path.join(run_output_dir, f"seg_{i}")
            os.makedirs(seg_dir, exist_ok=True)
            res.save_to_json(save_path=seg_dir)
            res.save_to_markdown(save_path=seg_dir)
            res.save_to_img(save_path=seg_dir)
            fnames = os.listdir(seg_dir)
            for fname in fnames:
                fpath = os.path.join(seg_dir, fname)
                # Pick up the annotated visualisation images the pipeline wrote.
                if fname.endswith((".png", ".jpg", ".jpeg")) and ("res" in fname or "vis" in fname):
                    vis_src = image_to_base64_data_url(fpath)
                    new_vis = f'<div style="margin-bottom:20px; border: 2px solid #10b981; border-radius: 12px; overflow: hidden; background:white;">'
                    new_vis += f'<img src="{vis_src}" alt="Page {i+1}" style="width:100%;"></div>'
                    # Membership check avoids duplicating a page preview.
                    if new_vis not in vis_html:
                        vis_html += new_vis
            # Stream a partial update after each segment.
            yield "β Reconstructing structure...", vis_html, "β Reconstructing...", "{}", ""
        logger.info("π Restructuring pages for final output...")
        progress(0.9, desc="Restructuring pages...")
        reconstructed_output = pipeline.restructure_pages(
            pages_res,
            merge_tables=True,
            relevel_titles=True,
            concatenate_pages=True
        )
        for i, res in enumerate(reconstructed_output):
            res.save_to_json(save_path=run_output_dir)
            res.save_to_markdown(save_path=run_output_dir)
        # Collect the merged markdown/JSON files the restructuring step wrote;
        # the "not in" checks guard against concatenating duplicate content.
        fnames = os.listdir(run_output_dir)
        for fname in fnames:
            fpath = os.path.join(run_output_dir, fname)
            if fname.endswith(".md"):
                with open(fpath, 'r', encoding='utf-8') as f:
                    c = f.read()
                if c not in md_content: md_content += c + "\n\n"
            elif fname.endswith(".json"):
                with open(fpath, 'r', encoding='utf-8') as f:
                    c = f.read()
                if c not in json_content: json_content += c + "\n\n"
        final_md = _escape_inequalities_in_math(md_content) or "β οΈ No text recognized."
        # NOTE(review): json_content may concatenate multiple JSON documents
        # (joined with blank lines); draw_layout_map parses a single document
        # and will return None in that case — verify multi-file runs.
        layout_map_src = draw_layout_map(img_path, json_content)
        layout_html = ""
        if layout_map_src:
            layout_html = f'<div style="border: 2px solid #3b82f6; border-radius: 12px; overflow: hidden;"><img src="{layout_map_src}" style="width:100%;"></div>'
        progress(1.0, desc="Complete")
        # Final full update with the complete markdown, JSON and layout map.
        yield final_md, vis_html, md_content, json_content, layout_html
        logger.info("--- β Inference and Restructuring Finished ---")
    except Exception as e:
        logger.error(f"β Inference Error: {e}")
        logger.error(traceback.format_exc())
        yield f"β Error: {str(e)}", "", "", "{}", ""
# --- UI Layout ---
# Custom CSS injected into the Gradio app: gradient header banner, teal
# status card, and rounded output panels. Kept as one literal string so it
# can be passed straight to gr.Blocks(css=...).
custom_css = """
body, .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
.app-header {
    text-align: center;
    padding: 2.5rem;
    background: linear-gradient(135deg, #0284c7 0%, #10b981 100%);
    color: white;
    border-radius: 1.5rem;
    margin-bottom: 2rem;
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
}
.app-header h1 { color: white !important; font-weight: 800; font-size: 2.5rem; }
.status-card {
    background-color: #0d9488 !important;
    color: white !important;
    padding: 1.25rem;
    border-radius: 1rem;
    margin-bottom: 2.5rem;
    font-weight: 600;
    text-align: center;
    box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06);
}
.status-card p { margin: 0; color: white !important; }
.output-box { border: 1px solid #e2e8f0 !important; border-radius: 1rem !important; }
"""
with gr.Blocks(theme=gr.themes.Ocean(), css=custom_css) as demo:
    # Header banner.
    gr.HTML(
        """
        <div class="app-header">
            <h1>π Tachiwin Document Parsing OCR π¦‘</h1>
            <p>Advancing linguistic rights with state-of-the-art document parsing</p>
        </div>
        """
    )
    # Status card reflects whether setup_pipeline() succeeded at import time.
    status_msg = "Pipeline Ready" if pipeline else "Initialization in progress..."
    gr.HTML(
        f"""
        <div class="status-card">
            <p><strong>β‘ Status:</strong> {status_msg} |
            <strong>Model:</strong> <code>{REPO_ID}</code> |
            <strong>Hardware:</strong> CPU</p>
        </div>
        """
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=5):
            img_input = gr.Image(label="Upload Image", type="filepath")
            mode_selector = gr.Dropdown(
                choices=["Full Document Parsing", "Formula Recognition", "Table Recognition", "Text Recognition", "Feature Spotting"],
                value="Full Document Parsing",
                label="Analysis Mode"
            )
            with gr.Accordion("Advanced Options", open=False):
                chart_check = gr.Checkbox(label="Chart Recognition", value=True)
                unwarp_check = gr.Checkbox(label="Document Unwarping / Orientation Fix", value=False)
            btn_run = gr.Button("π Start Analysis", variant="primary")
        # Right column: tabbed result views fed by run_inference's 5-tuples.
        with gr.Column(scale=7):
            with gr.Tabs():
                with gr.Tab("π Markdown View"):
                    md_view = gr.Markdown(latex_delimiters=LATEX_DELIMS, elem_classes="output-box")
                with gr.Tab("πΌοΈ Visual Results"):
                    vis_view = gr.HTML('<div style="text-align:center; color:#94a3b8; padding: 50px;">Results will appear here.</div>')
                with gr.Tab("π Raw Source"):
                    raw_view = gr.Code(language="markdown")
                with gr.Tab("πΎ JSON Feed"):
                    json_view = gr.Code(language="json")
                with gr.Tab("πΊοΈ Layout Map"):
                    layout_view = gr.HTML('<div style="text-align:center; color:#94a3b8; padding: 50px;">Bboxes view will appear here.</div>')

    def unified_wrapper(img, mode, chart, unwarp, progress=gr.Progress()):
        """Validate input, show placeholders, then stream run_inference results."""
        if not img:
            yield "β οΈ Please upload an image.", "", "", "{}", ""
            return
        # Immediate feedback before the (slow) prediction starts.
        yield "β Running prediction...", gr.update(value="<div style='text-align:center;'>β Processing...</div>"), "β Processing...", "{}", gr.update(value="<div style='text-align:center;'>β Generating map...</div>")
        # Re-yield every streamed tuple so all five panes update live.
        for res_preview, res_vis, res_raw, res_json, res_map in run_inference(img, mode, chart, unwarp, progress=progress):
            yield res_preview, res_vis, res_raw, res_json, res_map

    btn_run.click(
        unified_wrapper,
        [img_input, mode_selector, chart_check, unwarp_check],
        [md_view, vis_view, raw_view, json_view, layout_view],
        show_progress="full"
    )
    gr.Markdown("--- \n *Tachiwin Project: Indigenous Languages of Mexico.*")

if __name__ == "__main__":
    # queue() is required for generator (streaming) event handlers.
    demo.queue().launch()