# math-solver/vision_ocr/pipeline.py
# Cuong2004
# Deploy API from GitHub Actions
# 395651c
"""OCR vision stack only (no LLM). Used by Celery OCR worker and composed by ``agents.ocr_agent``."""
from __future__ import annotations
import logging
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
from .compat import allow_ultralytics_weights
# Module-level logger; every message in this file is prefixed with "[OcrVisionPipeline]".
logger = logging.getLogger(__name__)
# Longest image edge (in pixels) allowed before the image is downscaled for OCR.
_OCR_MAX_EDGE = 2000
# Padding (in pixels) added on each side of a text box when cropping it for Pix2Tex.
_CROP_PAD = 4
class OcrVisionPipeline:
    """
    Hybrid OCR pipeline (no LLM involved).

    1. YOLO for layout analysis (weights preloaded; layout path reserved --
       ``layout_model`` is not used by any method of this class yet).
    2. PaddleOCR for Vietnamese text extraction.
    3. Pix2Tex for LaTeX formula extraction.

    Every engine is optional: if its import or construction fails, the
    corresponding attribute (``layout_model``, ``text_model``, ``math_model``)
    is set to ``None`` and the pipeline degrades instead of raising.
    """

    # Characters whose presence in a recognised line triggers a Pix2Tex attempt.
    _MATH_HINT_CHARS = frozenset("\\^_{}=+-*/")

    # Lines scoring below this are logged at DEBUG level (they are NOT dropped).
    _LOW_CONFIDENCE = 0.45

    def __init__(self) -> None:
        """Load the three engines, tolerating the absence of any of them."""
        logger.info("[OcrVisionPipeline] Initializing engines...")

        # Broad ``except Exception`` is deliberate below: a missing or broken
        # optional dependency must not prevent the worker from starting.
        try:
            from ultralytics import YOLO

            allow_ultralytics_weights()
            logger.info("[OcrVisionPipeline] Loading YOLO...")
            self.layout_model = YOLO("yolov8n.pt")
            logger.info("[OcrVisionPipeline] YOLO initialized.")
        except Exception as e:
            logger.error("[OcrVisionPipeline] YOLO init failed: %s", e)
            self.layout_model = None

        try:
            from paddleocr import PaddleOCR

            logger.info("[OcrVisionPipeline] Loading PaddleOCR...")
            self.text_model = PaddleOCR(use_angle_cls=True, lang="vi")
            logger.info("[OcrVisionPipeline] PaddleOCR (vi) initialized.")
        except Exception as e:
            logger.error("[OcrVisionPipeline] PaddleOCR init failed: %s", e)
            self.text_model = None

        try:
            from pix2tex.cli import LatexOCR

            logger.info("[OcrVisionPipeline] Loading Pix2Tex...")
            self.math_model = LatexOCR()
            logger.info("[OcrVisionPipeline] Pix2Tex initialized.")
        except Exception as e:
            logger.error("[OcrVisionPipeline] Pix2Tex init failed: %s", e)
            self.math_model = None

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Best-effort removal of a temporary file; never raises."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.debug("[OcrVisionPipeline] Could not remove temp file %s: %s", path, e)

    @staticmethod
    def _quad_top_y(bbox: Any) -> int:
        """Return the smallest y coordinate of a box given as a sequence of (x, y) points.

        Used as the vertical sort key of a text line. Returns 0 when ``bbox``
        is missing, empty or malformed so one bad box cannot abort a whole page.
        """
        try:
            return int(min(float(p[1]) for p in bbox))
        except (TypeError, ValueError, IndexError):
            return 0

    def _preprocess_image_for_ocr(self, src_path: str) -> Tuple[str, bool]:
        """Resize large images, apply CLAHE and denoising, and save the result.

        Returns:
            ``(path, created)``. When ``created`` is True, ``path`` is a new
            temporary PNG that the caller must delete. When False, ``path`` is
            ``src_path`` itself (the image could not be read or the
            preprocessed copy could not be written).
        """
        img = self._load_bgr_for_crops(src_path)
        if img is None:
            logger.warning("[OcrVisionPipeline] OpenCV could not read %s; using original.", src_path)
            return src_path, False

        h, w = img.shape[:2]
        max_dim = max(h, w)
        if max_dim > _OCR_MAX_EDGE:
            scale = _OCR_MAX_EDGE / max_dim
            # Clamp to 1 px: an extremely thin image would otherwise round a
            # dimension down to 0, which cv2.resize rejects.
            new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
            img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
            logger.info("[OcrVisionPipeline] Resized for OCR to max edge %s", _OCR_MAX_EDGE)

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        # Positional args: dst=None, h=8, templateWindowSize=7, searchWindowSize=21.
        den = cv2.fastNlMeansDenoising(gray, None, 8, 7, 21)

        out = f"temp_ocr_prep_{uuid.uuid4().hex}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(out, den):
            logger.warning(
                "[OcrVisionPipeline] Could not write preprocessed image %s; using original.", out
            )
            self._remove_quietly(out)
            return src_path, False
        return out, True

    def _load_bgr_for_crops(self, path: str) -> Optional[np.ndarray]:
        """Read ``path`` as a 3-channel BGR image, or return None if unreadable.

        Falls back to a greyscale read (converted to BGR) when the colour
        read fails.
        """
        im = cv2.imread(path, cv2.IMREAD_COLOR)
        if im is not None:
            return im
        g = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if g is None:
            return None
        return cv2.cvtColor(g, cv2.COLOR_GRAY2BGR)

    def _crop_from_quad(self, img_bgr: np.ndarray, bbox: Any) -> Optional[np.ndarray]:
        """Crop the axis-aligned bounding rectangle of ``bbox`` plus ``_CROP_PAD`` pixels.

        ``bbox`` is a sequence of (x, y) points. The rectangle is clipped to
        the image bounds. Returns a copy of the region, or None when the box
        is empty or malformed.
        """
        try:
            pts = np.array(bbox, dtype=np.float32)
            xs = pts[:, 0]
            ys = pts[:, 1]
            H, W = img_bgr.shape[:2]
            x1 = max(0, int(xs.min()) - _CROP_PAD)
            y1 = max(0, int(ys.min()) - _CROP_PAD)
            x2 = min(W, int(xs.max()) + _CROP_PAD)
            y2 = min(H, int(ys.max()) + _CROP_PAD)
            if x2 <= x1 or y2 <= y1:
                return None
            return img_bgr[y1:y2, x1:x2].copy()
        except Exception as e:
            # Deliberately best-effort: a bad box only disables the math
            # replacement for that single line.
            logger.debug("[OcrVisionPipeline] crop failed: %s", e)
            return None

    def _latex_from_crop_bgr(self, crop_bgr: Optional[np.ndarray]) -> Optional[str]:
        """Run Pix2Tex on a BGR crop and return stripped LaTeX, or None.

        Returns None when no math model is loaded, when the crop is missing,
        empty or smaller than 10x10 pixels, or when the model fails or yields
        no text.
        """
        if self.math_model is None or crop_bgr is None or crop_bgr.size == 0:
            return None
        ch, cw = crop_bgr.shape[:2]
        if ch < 10 or cw < 10:
            return None
        try:
            from PIL import Image

            rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
            out = self.math_model(Image.fromarray(rgb))
            if isinstance(out, str) and out.strip():
                return out.strip()
        except Exception as e:
            logger.debug("[OcrVisionPipeline] Pix2Tex on crop failed: %s", e)
        return None

    def _maybe_math_from_crop(self, img_bgr: Optional[np.ndarray], bbox: Any, text: str) -> str:
        """Return ``$latex$`` for a line that looks like math, else ``text`` unchanged.

        A line "looks like math" when it contains any character from
        ``_MATH_HINT_CHARS``. The original text is kept whenever the image or
        the math model is unavailable, or Pix2Tex produces nothing.
        """
        if img_bgr is None or not self.math_model:
            return text
        if self._MATH_HINT_CHARS.isdisjoint(text):
            return text
        crop = self._crop_from_quad(img_bgr, bbox)
        latex = self._latex_from_crop_bgr(crop) if crop is not None else None
        if latex:
            logger.info("[OcrVisionPipeline] Pix2Tex replaced line fragment (len=%s)", len(latex))
            return f"${latex}$"
        return text

    def _collect_fragments(
        self, result: Any, img_bgr: Optional[np.ndarray]
    ) -> List[Dict[str, Any]]:
        """Convert a non-empty PaddleOCR result into ``{"y", "content", "type"}`` fragments.

        Two layouts of ``result[0]`` are handled:

        * a dict with parallel ``rec_texts`` / ``rec_scores`` / ``rec_polys`` entries;
        * a list of ``[quad, (text, score)]`` entries.

        Any other value (e.g. ``None``) yields no fragments.
        """
        first = result[0]
        lines: List[Tuple[Any, str, Any]] = []
        if isinstance(first, dict):
            texts = first.get("rec_texts", [])
            scores = first.get("rec_scores", [])
            polys = first.get("rec_polys", [])
            for i, text in enumerate(texts):
                # The parallel lists are not guaranteed to have equal length;
                # keep the text even when its box or score is missing.
                bbox = polys[i] if i < len(polys) else None
                score = scores[i] if i < len(scores) else None
                lines.append((bbox, text, score))
        elif isinstance(first, list):
            for line in first:
                rec = line[1]
                lines.append((line[0], rec[0], rec[1] if len(rec) > 1 else None))

        fragments: List[Dict[str, Any]] = []
        for bbox, text, score in lines:
            if score is not None and float(score) < self._LOW_CONFIDENCE:
                logger.debug(
                    "[OcrVisionPipeline] Low-confidence line (score=%s): %s",
                    score,
                    text[:80],
                )
            fragments.append(
                {
                    "y": self._quad_top_y(bbox),
                    "content": self._maybe_math_from_crop(img_bgr, bbox, text),
                    "type": "text",
                }
            )
        return fragments

    async def process_image(self, image_path: str) -> str:
        """Return assembled raw OCR text (no LLM).

        Lines are ordered top-to-bottom by the top edge of their box and
        joined with newlines. Returns an ``"Error: ..."`` string when the
        file does not exist and ``""`` when nothing is recognised.
        """
        logger.info("==[OcrVisionPipeline] Processing: %s==", image_path)
        if not os.path.exists(image_path):
            return f"Error: File {image_path} not found."

        # When preprocessing fails, prep_path is image_path and nothing is cleaned up.
        prep_path, prep_cleanup = self._preprocess_image_for_ocr(image_path)
        raw_fragments: List[Dict[str, Any]] = []
        try:
            # Crops are taken from the same image PaddleOCR sees, so box
            # coordinates line up even after downscaling.
            img_bgr = self._load_bgr_for_crops(prep_path)
            if self.text_model:
                logger.info("[OcrVisionPipeline] Running PaddleOCR on %s...", prep_path)
                result = self.text_model.ocr(prep_path)
                logger.info("[OcrVisionPipeline] PaddleOCR raw result: %s", result)
                if not result:
                    logger.warning("[OcrVisionPipeline] PaddleOCR returned no results.")
                    return ""
                raw_fragments = self._collect_fragments(result, img_bgr)
        finally:
            if prep_cleanup:
                self._remove_quietly(prep_path)

        raw_fragments.sort(key=lambda frag: frag["y"])
        combined_text = "\n".join(frag["content"] for frag in raw_fragments)
        logger.info(
            "[OcrVisionPipeline] Raw OCR output assembled:\n---\n%s\n---", combined_text
        )
        if not combined_text.strip():
            logger.warning("[OcrVisionPipeline] No text detected.")
            return ""
        return combined_text

    async def process_url(self, url: str) -> str:
        """Download image and run ``process_image`` (raw only).

        Returns an ``"Error: ..."`` string when the URL is empty after
        sanitising or the response status is not 200.
        """
        import httpx
        from app.url_utils import sanitize_url

        url = sanitize_url(url)
        if not url:
            return "Error: Empty image URL after cleanup."

        async with httpx.AsyncClient() as client:
            resp = await client.get(url)
        if resp.status_code != 200:
            return f"Error: Failed to fetch image from URL {url}"

        # A unique name per call: a shared fixed filename would let concurrent
        # downloads overwrite or delete each other's image.
        temp_path = f"temp_url_image_{uuid.uuid4().hex}.png"
        try:
            with open(temp_path, "wb") as f:
                f.write(resp.content)
            return await self.process_image(temp_path)
        finally:
            self._remove_quietly(temp_path)