# engineering-drawing / src/inference.py
# Harry Pham
# update OCR
# d640e48
# src/inference.py
import torch
# Keep a handle on the stock loader so the wrapper below can delegate to it.
_orig_torch_load = torch.load


def _patched_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    A ``weights_only`` keyword supplied by the caller is left untouched; all
    other positional and keyword arguments are forwarded unchanged.

    NOTE(review): ``weights_only=False`` allows full unpickling, so only
    trusted checkpoint files should be loaded through this wrapper.
    """
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _orig_torch_load(*args, **kwargs)


# Route every later ``torch.load`` call in this process through the wrapper.
torch.load = _patched_load
import cv2
import json
import numpy as np
from pathlib import Path
from ultralytics import RTDETR
import re
from difflib import SequenceMatcher
# Compute device: "cuda" when torch reports a usable GPU, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Device: {DEVICE}")
# Detection class labels, in list order.
CLASS_NAMES = ["note", "part-drawing", "table"]
# Display name for each class label.
CLASS_DISPLAY = {"note": "Note", "part-drawing": "PartDrawing", "table": "Table"}
# Colour triple per class label (presumably BGR for OpenCV drawing; confirm against the caller).
COLORS = {"note": (0,165,255), "part-drawing": (0,200,0), "table": (0,0,220)}
# Lazily-created singletons: each stays None until its loader function
# (get_det_model, get_paddle_reader, get_easyocr_reader, get_vietocr) fills it.
_det_model = None
_ocr_paddle = None
_ocr_paddle_en = None
_ocr_easyocr = None
_ocr_vietocr = None
# Optional dependency probe: the flag is True only when both Real-ESRGAN
# imports succeed.
REALESRGAN_AVAILABLE = False
_esrgan_upsampler = None  # cached RealESRGANer instance, created on first use
try:
    from realesrgan import RealESRGANer
    from basicsr.archs.rrdbnet_arch import RRDBNet
except ImportError:
    print("[WARN] Real-ESRGAN not installed. Install: pip install realesrgan basicsr")
else:
    REALESRGAN_AVAILABLE = True
    print("[INFO] Real-ESRGAN is available")


def get_esrgan_upsampler():
    """Return the shared Real-ESRGAN upsampler, building it on first use.

    Returns:
        The cached ``RealESRGANer`` instance, or ``None`` when Real-ESRGAN is
        not installed or the model could not be constructed. A failed build
        is not cached, so the next call tries again.
    """
    global _esrgan_upsampler
    if not REALESRGAN_AVAILABLE:
        return None
    if _esrgan_upsampler is not None:
        return _esrgan_upsampler
    try:
        print("[INFO] Loading Real-ESRGAN model...")
        network = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        _esrgan_upsampler = RealESRGANer(
            scale=4,
            model_path='weights/RealESRGAN_x4plus_anime_6B.pth',
            model=network,
            device=DEVICE,
        )
    except Exception as e:
        print(f"[WARN] Failed to load Real-ESRGAN: {e}")
        return None
    return _esrgan_upsampler
def upscale_if_needed(img_bgr, min_dim=300):
    """Upscale a small image 2x with Real-ESRGAN.

    The image is upscaled when its height OR its width is below ``min_dim``.
    If the upsampler is unavailable or enhancement fails, the input image is
    returned unchanged.

    Args:
        img_bgr: Image array whose first two dimensions are height and width.
        min_dim: Size threshold in pixels applied to each dimension.

    Returns:
        The enhanced image, or ``img_bgr`` itself when no upscaling happened.
    """
    height, width = img_bgr.shape[:2]
    if height >= min_dim and width >= min_dim:
        return img_bgr

    upsampler = get_esrgan_upsampler()
    if upsampler is None:
        return img_bgr

    try:
        upscaled, _ = upsampler.enhance(img_bgr, outscale=2)
    except Exception as e:
        # Best effort: fall back to the untouched input.
        print(f"[WARN] ESRGAN upscale failed: {e}")
        return img_bgr
    return upscaled
# ============================================================
# DOMAIN DICTIONARY — Từ điển bản vẽ kỹ thuật Việt Nam
# ============================================================
# Từ điển các từ thường gặp trong bảng kê bản vẽ kỹ thuật
TECH_DICTIONARY = {
# Tên chi tiết
"bọc táp": "Bọc táp",
"boc tap": "Bọc táp",
"bạc táp": "Bọc táp",
"bọc tốp": "Bọc táp",
"bọc tot": "Bọc táp",
"vòng đệm": "Vòng đệm",
"vong dem": "Vòng đệm",
"vòng dệm": "Vòng đệm",
"vong đệm": "Vòng đệm",
"chốt trụ": "Chốt trụ",
"chot tru": "Chốt trụ",
"chốt trự": "Chốt trụ",
"chot trụ": "Chốt trụ",
"vít": "Vít",
"vit": "Vít",
"bu lông": "Bu lông",
"bu long": "Bu lông",
"bulong": "Bu lông",
"bu-lông": "Bu lông",
"bulông": "Bu lông",
"vòng đệm vênh": "Vòng đệm vênh",
"vong dem venh": "Vòng đệm vênh",
"vòng dệm vênh": "Vòng đệm vênh",
"then bằng": "Then bằng",
"then bang": "Then bằng",
"then bảng": "Then bằng",
"ống dẫn": "Ống dẫn",
"ong dan": "Ống dẫn",
"ống chẫn": "Ống dẫn",
"ống dần": "Ống dẫn",
"ông dẫn": "Ống dẫn",
"ông chẫn": "Ống dẫn",
"ống chến": "Ống dẫn",
"ông chến": "Ống dẫn",
"chốt chặn": "Chốt chặn",
"chot chan": "Chốt chặn",
"chốt chắn": "Chốt chặn",
"cnốt chến": "Chốt chặn",
"chốt chén": "Chốt chặn",
"bạc lót": "Bạc lót",
"bac lot": "Bạc lót",
"bọc lót": "Bạc lót",
"bạc lốt": "Bạc lót",
"bọc lết": "Bạc lót",
"bọc lết": "Bạc lót",
"giá đỡ": "Giá đỡ",
"gia do": "Giá đỡ",
"giá dở": "Giá đỡ",
"giá đở": "Giá đỡ",
"bánh răng": "Bánh răng",
"banh rang": "Bánh răng",
"bành răng": "Bánh răng",
"bánh rằng": "Bánh răng",
"bảnh răng": "Bánh răng",
"bdnh răng": "Bánh răng",
"bdình răng": "Bánh răng",
"hộp bánh răng": "Hộp bánh răng",
"hop banh rang": "Hộp bánh răng",
"hộp bành răng": "Hộp bánh răng",
"mộp bành răng": "Hộp bánh răng",
"mộp bánh răng": "Hộp bánh răng",
"nắp": "Nắp",
"nap": "Nắp",
"năp": "Nắp",
"nốp": "Nắp",
# Vật liệu
"đồng nhôm": "Đồng nhôm",
"dong nhom": "Đồng nhôm",
"đồng thanh": "Đồng nhôm",
"đồng thann": "Đồng nhôm",
"đổng nhôm": "Đồng nhôm",
"đống nhôm": "Đồng nhôm",
"đống thanh": "Đồng nhôm",
"thép ct3": "Thép CT3",
"thep ct3": "Thép CT3",
"thếp ct3": "Thép CT3",
"tnép ct3": "Thép CT3",
"thếp cts": "Thép CT3",
"tnếp ct3": "Thép CT3",
"thếp ctj": "Thép CT3",
"tnép ctj": "Thép CT3",
"tnếp ctj": "Thép CT3",
"thep ctj": "Thép CT3",
"thép 65": "Thép 65",
"thep 65": "Thép 65",
"thếp 65": "Thép 65",
"tnếp 65": "Thép 65",
"thếp 65f": "Thép 65",
"tnép 65f": "Thép 65",
"thép 6sf": "Thép 65",
"thep 6sf": "Thép 65",
"thếp 6sf": "Thép 65",
"tnép 6sf": "Thép 65",
"thép 45": "Thép 45",
"thep 45": "Thép 45",
"thếp 45": "Thép 45",
"tnếp 45": "Thép 45",
"sắt tây": "Sắt tây",
"sat tay": "Sắt tây",
"sắt tay": "Sắt tây",
"sdi tay": "Sắt tây",
"sdi day": "Sắt tây",
"sdi đay": "Sắt tây",
"gang 15-32": "Gang 15-32",
"gang15-32": "Gang 15-32",
"gong 15-32": "Gang 15-32",
"gong15-32": "Gang 15-32",
"gang 15.32": "Gang 15-32",
"gang 15 32": "Gang 15-32",
"gong 15.32": "Gang 15-32",
"gang1532": "Gang 15-32",
# Header
"vị trí": "Vị trí",
"vi tri": "Vị trí",
"v.trí": "Vị trí",
"tên chi tiết": "Tên chi tiết",
"ten chi tiet": "Tên chi tiết",
"tên chi tiết máy": "Tên chi tiết máy",
"ten chi tiet may": "Tên chi tiết máy",
"số lg": "Số lg",
"so lg": "Số lg",
"số lượng": "Số lg",
"so luong": "Số lg",
"s.lg": "Số lg",
"số lý": "Số lg",
"vật liệu": "Vật liệu",
"vat lieu": "Vật liệu",
"vat liéu": "Vật liệu",
"ghi chú": "Ghi chú",
"ghi chu": "Ghi chú",
# Title block
"bản vẽ số": "Bản vẽ số",
"ban ve so": "Bản vẽ số",
"bản gối": "Bản vẽ số",
"bơm bánh răng": "BƠM BÁNH RĂNG",
"bom banh rang": "BƠM BÁNH RĂNG",
"bớm bánh răng": "BƠM BÁNH RĂNG",
"bản vẽ lắp số": "Bản vẽ lắp số",
"ban ve lap so": "Bản vẽ lắp số",
"bản vể lắp số": "Bản vẽ lắp số",
"bán vẽ lắp số": "Bản vẽ lắp số",
"bán vể lắp số": "Bản vẽ lắp số",
"tỷ lệ": "Tỷ lệ",
"ty le": "Tỷ lệ",
"tý lệ": "Tỷ lệ",
"bộ môn hình hoạ": "Bộ môn Hình hoạ",
"bộ môn hình họa": "Bộ môn Hình hoạ",
"bo mon hinh hoa": "Bộ môn Hình hoạ",
"bộ mốn hình hoạ": "Bộ môn Hình hoạ",
"đại học bách khoa hà nội": "Đại học Bách khoa Hà Nội",
"dai hoc bach khoa ha noi": "Đại học Bách khoa Hà Nội",
"đại học bách khoa": "Đại học Bách khoa Hà Nội",
"bại hoc bách khoa": "Đại học Bách khoa Hà Nội",
"bại học bách khoa hà nội": "Đại học Bách khoa Hà Nội",
}
# Canonical part names for fuzzy matching
# (used by correct_technical_text for the "name" column and in "auto" mode).
CANONICAL_PARTS = [
"Bọc táp", "Vòng đệm", "Chốt trụ", "Vít", "Bu lông",
"Vòng đệm vênh", "Then bằng", "Ống dẫn", "Chốt chặn",
"Bạc lót", "Giá đỡ", "Bánh răng", "Hộp bánh răng", "Nắp",
]
# Canonical material names (used for the "material" column and in "auto" mode).
CANONICAL_MATERIALS = [
"Đồng nhôm", "Thép CT3", "Thép 65", "Thép 45",
"Sắt tây", "Gang 15-32",
]
# Canonical table header labels (checked for the "name" column and in "auto" mode).
CANONICAL_HEADERS = [
"Vị trí", "Tên chi tiết", "Tên chi tiết máy", "Số lg",
"Vật liệu", "Ghi chú",
]
def fuzzy_match(text, candidates, threshold=0.55):
    """Map OCR text onto the closest known term.

    The lower-cased, stripped text is first looked up in ``TECH_DICTIONARY``.
    Otherwise every candidate is scored with ``SequenceMatcher.ratio`` and the
    first best-scoring candidate is returned if its score reaches
    ``threshold``.

    Args:
        text: Raw OCR text.
        candidates: Canonical strings to compare against.
        threshold: Minimum similarity ratio required to accept a candidate.

    Returns:
        The dictionary value or best candidate; ``text`` unchanged when it is
        empty, there are no candidates, or nothing scores high enough.
    """
    if not text or not candidates:
        return text

    needle = text.lower().strip()

    # An exact dictionary hit wins over any fuzzy comparison.
    try:
        return TECH_DICTIONARY[needle]
    except KeyError:
        pass

    top_candidate = None
    top_score = 0
    for option in candidates:
        similarity = SequenceMatcher(None, needle, option.lower()).ratio()
        if similarity > top_score:
            top_candidate, top_score = option, similarity

    return top_candidate if top_score >= threshold else text
def correct_technical_text(text, column_type="auto"):
    """Fix common OCR errors using domain knowledge about parts tables.

    Processing order:
      1. Exact lookup of the lower-cased text in ``TECH_DICTIONARY``.
      2. Column-specific handling: numeric columns keep digits only; name and
         material columns are fuzzy-matched against the canonical lists.
      3. Generic normalisation of thread specs ("M 6" -> "M6") and dimension
         strings ("5 x 8 - 35" -> "5x8-35", "4*6*14" -> "4x6x14").

    Args:
        text: Raw OCR text of one cell.
        column_type: One of "position", "name", "quantity", "material",
            "note" or "auto". "auto" guesses from the text itself.

    Returns:
        The corrected text; empty or whitespace-only input is returned as-is.
    """
    if not text or not text.strip():
        return text
    original = text.strip()
    text_lower = original.lower()

    # 1. Exact dictionary lookup
    if text_lower in TECH_DICTIONARY:
        return TECH_DICTIONARY[text_lower]

    is_auto = column_type == "auto"
    digits_only = re.sub(r'[^0-9]', '', original)

    # 2. Column-specific corrections
    # Position column (or auto text made of digits and separators): keep the
    # digits, fall back to the stripped text when none are present.
    looks_numeric = original.replace('.', '').replace(',', '').isdigit()
    if column_type == "position" or (is_auto and looks_numeric):
        return digits_only or original

    # Quantity column, or short auto text containing a digit. A two-character
    # thread spec such as "M6" is excluded from the auto guess: it is a bolt
    # size, not a count, and is normalised by the M-spec rule in step 3.
    is_thread_spec = re.fullmatch(r'[Mm][0-9]', original) is not None
    looks_like_count = (
        len(original) <= 2
        and any(c.isdigit() for c in original)
        and not is_thread_spec
    )
    if column_type == "quantity" or (is_auto and looks_like_count):
        return digits_only or original

    if column_type == "name":
        # Known part names first, then header labels.
        for candidates in (CANONICAL_PARTS, CANONICAL_HEADERS):
            result = fuzzy_match(original, candidates, threshold=0.5)
            if result != original:
                return result

    if column_type == "material":
        result = fuzzy_match(original, CANONICAL_MATERIALS, threshold=0.5)
        if result != original:
            return result

    if is_auto:
        # Try every category in turn.
        for candidates in (CANONICAL_PARTS, CANONICAL_MATERIALS, CANONICAL_HEADERS):
            result = fuzzy_match(original, candidates, threshold=0.55)
            if result != original:
                return result

    # 3. General corrections
    text_out = original
    # Thread specs: "M 6" / "m6" -> "M6". The lookbehind rejects an "m" that
    # is preceded by a letter, so units and words ("10mm 2") are left alone.
    text_out = re.sub(r'(?<![^\W\d_])[Mm]\s*(\d)', r'M\1', text_out)
    # "5 x 8 - 35" style dimensions
    text_out = re.sub(r'(\d+)\s*[xX×]\s*(\d+)\s*[-–]\s*(\d+)', r'\1x\2-\3', text_out)
    text_out = re.sub(r'(\d+)\s*[xX×]\s*(\d+)\s*[xX×]\s*(\d+)', r'\1x\2x\3', text_out)
    # An "n" between two numbers is treated as a misread "x" ("3n8" -> "3x8").
    text_out = re.sub(r'(\d+)\s*n\s*(\d+)', r'\1x\2', text_out)
    # "4*6*14" -> "4x6x14"
    text_out = re.sub(r'(\d+)\s*[*]\s*(\d+)\s*[*]\s*(\d+)', r'\1x\2x\3', text_out)
    return text_out
def correct_table_row(row, num_columns=5):
    """Correct every cell of one table row according to its column position.

    Expected column order: [position, part name, quantity, material, note].
    The row is padded with empty strings or truncated to ``num_columns``.
    Only columns that exist are corrected, so a row narrower than five
    columns no longer raises ``IndexError``; columns beyond the fifth are
    kept untouched.

    Args:
        row: Sequence of cell strings.
        num_columns: Number of columns the returned row must have.

    Returns:
        A new list of ``num_columns`` corrected cells, or ``row`` itself when
        it is empty.
    """
    if not row:
        return row

    # Normalise the width: trim the excess, then pad what is missing.
    corrected = list(row)[:num_columns]
    corrected.extend([""] * (num_columns - len(corrected)))

    # Correction mode for each leading column; the note column uses "auto".
    column_types = ("position", "name", "quantity", "material", "auto")
    for idx, column_type in zip(range(len(corrected)), column_types):
        if corrected[idx]:
            corrected[idx] = correct_technical_text(corrected[idx], column_type)
    return corrected
# ============================================================
# MODEL LOADERS
# ============================================================
def get_det_model(checkpoint="best.pt"):
    """Return the shared RT-DETR detection model, loading it on first use.

    Args:
        checkpoint: Weights passed to ``RTDETR``. Only the first call uses
            it; later calls return the cached model whatever is passed.

    Returns:
        The cached ``RTDETR`` instance.
    """
    global _det_model
    if _det_model is not None:
        return _det_model
    print(f"[INFO] Loading detection model: {checkpoint}")
    _det_model = RTDETR(checkpoint)
    return _det_model
# ============================================================
# SURYA OCR (optional)
# ============================================================
# Optional dependency probe: the flag is True only when all Surya imports succeed.
SURYA_AVAILABLE = False
try:
    from surya.ocr import run_ocr
    from surya.model.detection.model import load_det_processor, load_det_model
    from surya.model.recognition.model import load_rec_model
    from surya.model.recognition.processor import load_rec_processor
    SURYA_AVAILABLE = True
    print("[INFO] Surya OCR is available")
except ImportError:
    print("[WARN] Surya OCR not installed. Install with: pip install surya-ocr")

# (det_processor, det_model, rec_model, rec_processor), loaded on first use.
_surya_components = None


def _get_surya_components():
    """Load the Surya processors and models once and reuse them afterwards."""
    global _surya_components
    if _surya_components is None:
        det_processor, det_model = load_det_processor(), load_det_model()
        rec_model, rec_processor = load_rec_model(), load_rec_processor()
        _surya_components = (det_processor, det_model, rec_model, rec_processor)
    return _surya_components


def ocr_with_surya(img_bgr, langs=("vi", "en")):
    """Run Surya OCR on a BGR image and return the recognised lines.

    The models are cached across calls, so per-cell OCR does not reload them
    each time. The default for ``langs`` is an immutable tuple rather than a
    shared list.

    Args:
        img_bgr: Image converted from BGR to RGB before recognition.
        langs: Iterable of language codes handed to ``run_ocr``.

    Returns:
        The text of every detected line joined with newlines.

    Raises:
        ImportError: If Surya OCR is not installed.
    """
    if not SURYA_AVAILABLE:
        raise ImportError("Surya OCR is not installed.")
    det_processor, det_model, rec_model, rec_processor = _get_surya_components()
    from PIL import Image
    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    predictions = run_ocr([pil_img], [list(langs)], det_model, det_processor,
                          rec_model, rec_processor)
    return "\n".join(line.text for line in predictions[0].text_lines)
# ============================================================
# VietOCR (optional - tốt cho chữ viết tay tiếng Việt)
# ============================================================
# Optional dependency probe: the flag is True only when both VietOCR imports succeed.
VIETOCR_AVAILABLE = False
try:
    from vietocr.tool.predictor import Predictor
    from vietocr.tool.config import Cfg
except ImportError:
    print("[WARN] VietOCR not installed. Install with: pip install vietocr")
else:
    VIETOCR_AVAILABLE = True
    print("[INFO] VietOCR is available")


def get_vietocr():
    """Return the shared VietOCR predictor, creating it on first use.

    Returns:
        The cached ``Predictor``, or ``None`` when VietOCR is not installed
        or construction failed (a failure is reported with a warning and is
        retried on the next call).
    """
    global _ocr_vietocr
    if _ocr_vietocr is not None or not VIETOCR_AVAILABLE:
        return _ocr_vietocr
    try:
        config = Cfg.load_config_from_name('vgg_transformer')
        config['cnn']['pretrained'] = True
        config['device'] = DEVICE
        _ocr_vietocr = Predictor(config)
        print("[INFO] VietOCR loaded successfully")
    except Exception as e:
        print(f"[WARN] VietOCR load failed: {e}")
    return _ocr_vietocr
def ocr_line_vietocr(img_bgr):
    """Recognise one text-line image with VietOCR.

    Returns:
        The stripped prediction, or "" when no predictor is available.
    """
    predictor = get_vietocr()
    if predictor is None:
        return ""
    from PIL import Image
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    prediction = predictor.predict(Image.fromarray(rgb))
    return prediction.strip()
# ============================================================
# PaddleOCR / EasyOCR
# ============================================================
def get_paddle_reader(lang='vi'):
    """Return a cached PaddleOCR reader, creating it on first use.

    Two cache slots exist: one for ``lang == 'en'`` and one shared by every
    other language value.

    Args:
        lang: Language code passed to ``PaddleOCR``.

    Returns:
        The reader, or ``None`` when PaddleOCR cannot be imported or
        initialised.
    """
    global _ocr_paddle, _ocr_paddle_en
    wants_english = lang == 'en'
    cached = _ocr_paddle_en if wants_english else _ocr_paddle
    if cached is not None:
        return cached

    try:
        from paddleocr import PaddleOCR
        print(f"[INFO] Initializing PaddleOCR PP-OCRv4 (lang={lang})...")
        reader = PaddleOCR(
            lang=lang,
            use_angle_cls=True,
            use_gpu=(DEVICE == "cuda"),
            show_log=False,
            ocr_version='PP-OCRv4',
            det_db_thresh=0.15,
            det_db_box_thresh=0.2,
            det_db_unclip_ratio=2.0,
            use_dilation=True,
            det_db_score_mode='slow',
            rec_image_shape="3,48,320",
            max_text_length=80,
            rec_batch_num=6,
        )
    except Exception as e:
        print(f"[WARN] PaddleOCR init failed: {e}")
        return None

    if wants_english:
        _ocr_paddle_en = reader
    else:
        _ocr_paddle = reader
    return reader
def get_easyocr_reader():
    """Return the shared EasyOCR reader for Vietnamese and English.

    The reader is built on the first call; ``easyocr`` is imported lazily, so
    a missing installation raises ``ImportError`` here rather than at module
    import time.
    """
    global _ocr_easyocr
    if _ocr_easyocr is not None:
        return _ocr_easyocr
    import easyocr
    use_gpu = DEVICE == "cuda"
    _ocr_easyocr = easyocr.Reader(["vi", "en"], gpu=use_gpu, verbose=False)
    return _ocr_easyocr
# ============================================================
# PREPROCESSING
# ============================================================
def enhance_faded_text(img_bgr):
    """Sharpen and binarise an image to recover faint strokes.

    Applies unsharp masking to the grayscale image, then a local threshold:
    Sauvola from scikit-image when importable, otherwise OpenCV's Gaussian
    adaptive threshold.

    Returns:
        The binarised image converted back to 3-channel BGR.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # Unsharp mask: 1.5 * gray - 0.5 * blurred emphasises edges.
    blurred = cv2.GaussianBlur(gray, (0, 0), 2.0)
    sharpened = cv2.addWeighted(gray, 1.5, blurred, -0.5, 0)

    try:
        from skimage.filters import threshold_sauvola
        local_thresh = threshold_sauvola(sharpened, window_size=25)
        binary = ((sharpened > local_thresh) * 255).astype(np.uint8)
    except ImportError:
        # scikit-image is missing: use the OpenCV equivalent.
        binary = cv2.adaptiveThreshold(
            sharpened, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, 21, 10,
        )
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
def preprocess_for_ocr(img_bgr, min_width=1500, mode="note"):
    """Denoise and contrast-enhance a colour image before OCR.

    Images narrower than ``min_width`` are first enlarged with cubic
    interpolation. ``mode == "note"`` uses milder filtering followed by a
    sharpening kernel; any other mode uses the stronger table settings
    without sharpening.

    Returns:
        The processed BGR image.
    """
    _, width = img_bgr.shape[:2]
    if width < min_width:
        factor = min_width / width
        img_bgr = cv2.resize(img_bgr, None, fx=factor, fy=factor,
                             interpolation=cv2.INTER_CUBIC)

    is_note = mode == "note"
    if is_note:
        diameter, sigma, clip_limit, tiles = 9, 75, 2.0, (8, 8)
    else:  # table
        diameter, sigma, clip_limit, tiles = 11, 80, 3.0, (4, 4)

    smoothed = cv2.bilateralFilter(img_bgr, diameter, sigma, sigma)

    # Equalise only the lightness channel so colours are preserved.
    l, a, b = cv2.split(cv2.cvtColor(smoothed, cv2.COLOR_BGR2LAB))
    l = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tiles).apply(l)
    enhanced = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)

    if not is_note:
        return enhanced

    sharpen = np.array([[0, -0.5, 0],
                        [-0.5, 3, -0.5],
                        [0, -0.5, 0]])
    return cv2.filter2D(enhanced, -1, sharpen)
def preprocess_for_handwriting(img_bgr, min_width=1800):
    """Prepare an image of handwriting for OCR.

    Enlarges narrow images, erases long horizontal and vertical rules (table
    borders), boosts contrast with CLAHE, binarises with an adaptive
    threshold and applies a small morphological closing.

    Returns:
        The binarised image as 3-channel BGR.
    """
    _, width = img_bgr.shape[:2]
    if width < min_width:
        factor = min_width / width
        img_bgr = cv2.resize(img_bgr, None, fx=factor, fy=factor,
                             interpolation=cv2.INTER_CUBIC)

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    rows, cols = gray.shape

    # Lines are found on the inverted image and painted white in a copy.
    inverted = ~gray
    without_lines = gray.copy()
    kernel_sizes = (
        (max(40, cols // 10), 1),  # horizontal rules
        (1, max(40, rows // 10)),  # vertical rules
    )
    for size in kernel_sizes:
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, size)
        line_mask = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, line_kernel,
                                     iterations=1)
        without_lines[line_mask > 128] = 255

    contrast = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(without_lines)
    binary = cv2.adaptiveThreshold(contrast, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY, 21, 10)

    # Light clean-up of small gaps.
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, close_kernel, iterations=1)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
def preprocess_grayscale_variant(img_bgr, min_width=1500):
    """Return an Otsu-binarised version of the image as 3-channel BGR.

    Images narrower than ``min_width`` are enlarged first; CLAHE is applied
    to the grayscale image before thresholding.
    """
    _, width = img_bgr.shape[:2]
    if width < min_width:
        factor = min_width / width
        img_bgr = cv2.resize(img_bgr, None, fx=factor, fy=factor,
                             interpolation=cv2.INTER_CUBIC)

    equalised = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(
        cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    )
    _, binary = cv2.threshold(equalised, 0, 255,
                              cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
# ============================================================
# OCR FUNCTIONS
# ============================================================
def ocr_single_pass(reader, img_bgr):
    """Run one OCR pass and return ``(items, avg_confidence)``.

    A reader exposing ``.ocr`` is treated as PaddleOCR (minimum confidence
    0.15); anything else is treated as EasyOCR via ``.readtext`` (minimum
    confidence 0.1). Detections with blank text are dropped.

    Returns:
        items: One dict per kept detection with keys ``text``, ``conf``,
            ``x``/``y`` (mean of the box corners), ``x1``/``y1``/``x2``/``y2``
            (box extremes) and ``box``.
        avg_confidence: Mean confidence of the kept detections, 0.0 if none.
    """
    if hasattr(reader, 'ocr'):  # PaddleOCR
        raw = reader.ocr(img_bgr, cls=True)
        if not raw or not raw[0]:
            return [], 0.0
        detections = [(quad, text, conf) for quad, (text, conf) in raw[0]]
        min_conf = 0.15
    else:  # EasyOCR
        detections = reader.readtext(img_bgr, detail=1, paragraph=False)
        min_conf = 0.1

    items = []
    for quad, text, conf in detections:
        if not (conf >= min_conf and text.strip()):
            continue
        xs = [corner[0] for corner in quad]
        ys = [corner[1] for corner in quad]
        items.append({
            "text": text.strip(),
            "conf": conf,
            "x": np.mean(xs),
            "y": np.mean(ys),
            "x1": min(xs), "y1": min(ys),
            "x2": max(xs), "y2": max(ys),
            "box": quad,
        })

    if not items:
        return items, 0.0
    return items, np.mean([entry["conf"] for entry in items])
def multi_pass_ocr(img_bgr, reader, ocr_type="note"):
    """OCR the image under five preprocessings and keep the most confident.

    Small images (for example corner crops) are upscaled before anything
    else. The five passes run in order: colour preprocessing, handwriting
    preprocessing, colour preprocessing at a larger width, grayscale Otsu,
    and faded-text enhancement.

    Returns:
        ``(items, confidence)`` of the pass with the highest average
        confidence; ``([], 0.0)`` when no pass scores above zero.
    """
    img_bgr = upscale_if_needed(img_bgr, min_dim=400)

    # Built lazily so each variant is prepared just before its OCR pass.
    variant_builders = (
        lambda: preprocess_for_ocr(img_bgr, min_width=1500, mode=ocr_type),
        lambda: preprocess_for_handwriting(img_bgr, min_width=1800),
        lambda: preprocess_for_ocr(img_bgr, min_width=2500, mode=ocr_type),
        lambda: preprocess_grayscale_variant(img_bgr, min_width=1500),
        lambda: enhance_faded_text(img_bgr),
    )

    best_items, best_conf = [], 0.0
    confidences = []
    for build_variant in variant_builders:
        items, conf = ocr_single_pass(reader, build_variant())
        confidences.append(conf)
        if conf > best_conf:
            best_items, best_conf = items, conf

    conf1, conf2, conf3, conf4, conf5 = confidences
    print(f" Multi-pass confidences: {conf1:.3f}, {conf2:.3f}, {conf3:.3f}, {conf4:.3f}, {conf5:.3f} → best={best_conf:.3f}")
    return best_items, best_conf
# ============================================================
# TABLE STRUCTURE — Intersection-based cell detection
# ============================================================
def detect_lines(gray, direction="horizontal", min_length_ratio=0.15):
    """Return a mask of long straight lines in a grayscale image.

    The image is inverse-thresholded at 150, opened with a 1-pixel-thick
    kernel whose length is ``min_length_ratio`` of the image width
    (horizontal) or height (any other direction), with a floor of 30 pixels,
    then dilated slightly to reconnect broken lines.
    """
    rows, cols = gray.shape
    _, inverted = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)

    if direction == "horizontal":
        kernel_size = (max(30, int(cols * min_length_ratio)), 1)
    else:
        kernel_size = (1, max(30, int(rows * min_length_ratio)))

    line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
    line_mask = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, line_kernel,
                                 iterations=2)

    bridge_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    return cv2.dilate(line_mask, bridge_kernel, iterations=1)
def find_line_positions(lines_img, direction="horizontal", merge_distance=10):
    """Locate lines in a line mask.

    Sums the mask along rows (horizontal) or columns (any other direction),
    keeps the indices whose sum exceeds 30% of the maximum, and merges
    indices that lie within ``merge_distance`` of the previous merged value
    by averaging the two.

    Returns:
        A list of y-coordinates (horizontal) or x-coordinates (otherwise);
        an empty list when nothing exceeds the threshold.
    """
    axis = 1 if direction == "horizontal" else 0
    profile = np.sum(lines_img, axis=axis)

    cutoff = np.max(profile) * 0.3
    hits = np.where(profile > cutoff)[0]
    if len(hits) == 0:
        return []

    merged = [hits[0]]
    for pos in hits[1:]:
        if pos - merged[-1] <= merge_distance:
            # Same line: move the stored position towards the new index.
            merged[-1] = (merged[-1] + pos) // 2
        else:
            merged.append(pos)
    return merged
def detect_table_cells_by_intersection(img_bgr):
    """Derive table cells from the grid of detected rule lines.

    Returns:
        A ``(cells, grid)`` tuple. ``cells`` is a flat list of
        ``(x1, y1, x2, y2)`` boxes and ``grid`` holds the same boxes grouped
        by row. When fewer than two horizontal or two vertical lines are
        found, ``cells`` comes from the contour-based fallback and ``grid``
        is ``None``.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    rows, cols = gray.shape

    horizontal_mask = detect_lines(gray, "horizontal", min_length_ratio=0.1)
    vertical_mask = detect_lines(gray, "vertical", min_length_ratio=0.1)

    y_positions = find_line_positions(horizontal_mask, "horizontal",
                                      merge_distance=max(8, rows // 50))
    x_positions = find_line_positions(vertical_mask, "vertical",
                                      merge_distance=max(8, cols // 50))
    print(f" Table grid: {len(y_positions)} horizontal × {len(x_positions)} vertical lines")

    if len(y_positions) < 2 or len(x_positions) < 2:
        # Not enough rules for a grid: use contour detection instead.
        return detect_table_structure(img_bgr), None

    cells = []
    grid = []
    for y1, y2 in zip(y_positions, y_positions[1:]):
        # Boxes narrower or shorter than 10 px are discarded.
        row_cells = [
            (x1, y1, x2, y2)
            for x1, x2 in zip(x_positions, x_positions[1:])
            if (x2 - x1) >= 10 and (y2 - y1) >= 10
        ]
        cells.extend(row_cells)
        if row_cells:
            grid.append(row_cells)
    return cells, grid
def detect_table_structure(img_bgr):
    """Contour-based fallback for finding table cells.

    Extracts horizontal and vertical rules by morphological opening, adds the
    two masks and takes the bounding box of every contour. Boxes are kept
    when their area lies strictly between 0.1% and 85% of the image and both
    sides exceed 15 px.

    Returns:
        Unique ``(x1, y1, x2, y2)`` boxes sorted top-to-bottom, then
        left-to-right.
    """
    img_h, img_w = img_bgr.shape[:2]
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    _, inverted = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)

    row_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(40, img_w // 15), 1))
    col_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(40, img_h // 15)))
    row_lines = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, row_kernel, iterations=2)
    col_lines = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, col_kernel, iterations=2)

    skeleton = cv2.add(row_lines, col_lines)
    contours, _ = cv2.findContours(skeleton, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    total_area = img_w * img_h
    smallest = total_area * 0.001
    largest = total_area * 0.85

    boxes = set()
    for contour in contours:
        x, y, box_w, box_h = cv2.boundingRect(contour)
        if smallest < box_w * box_h < largest and box_w > 15 and box_h > 15:
            boxes.add((x, y, x + box_w, y + box_h))
    return sorted(boxes, key=lambda r: (r[1], r[0]))
# ============================================================
# OCR TABLE — Grid-based approach
# ============================================================
def ocr_cell_improved(img_cell, backend="paddle"):
    """OCR one table cell with several engines and preprocessings.

    Steps:
      1. Upscale tiny cells (Real-ESRGAN when available, then cubic resize).
      2. Try VietOCR; a non-empty result starts as the best guess with an
         assumed confidence of 0.7. A VietOCR failure is reported with a
         warning instead of being silently swallowed.
      3. Run the selected backend on four variants (colour+CLAHE,
         handwriting, Otsu, faded-text) and keep the most confident
         non-empty text.
      4. For the paddle backend, also try the English reader on the colour
         variant and prefer it only if it contains capitals or digits.

    Args:
        img_cell: BGR crop of the cell.
        backend: "paddle", "surya", or anything else for EasyOCR.

    Returns:
        The best recognised text, or "" when the cell is empty or too small.
    """
    if img_cell is None or img_cell.size == 0:
        return ""

    # Upscale very small cells with ESRGAN
    img_cell = upscale_if_needed(img_cell, min_dim=150)
    h, w = img_cell.shape[:2]
    if h < 5 or w < 5:
        return ""

    # Bring short cells up to at least 64 px high.
    if h < 64:
        scale = 64 / h
        img_cell = cv2.resize(img_cell, None, fx=scale, fy=scale,
                              interpolation=cv2.INTER_CUBIC)
    # NOTE(review): `w` is the width measured before the height resize above,
    # so a cell that was both short and narrow is enlarged twice. Kept as-is
    # to preserve current OCR behaviour; confirm whether that is intended.
    if w < 200:
        scale_w = 200 / w
        img_cell = cv2.resize(img_cell, None, fx=scale_w, fy=scale_w,
                              interpolation=cv2.INTER_CUBIC)

    best_text = ""
    best_conf = 0

    # Try VietOCR first (better for handwriting)
    if VIETOCR_AVAILABLE:
        try:
            vietocr_text = ocr_line_vietocr(img_cell)
        except Exception as e:
            # Best effort: report the failure and continue with other engines.
            print(f"[WARN] VietOCR cell recognition failed: {e}")
        else:
            if vietocr_text:
                best_text = vietocr_text
                best_conf = 0.7  # VietOCR reports no score; use a fixed default

    # Pick the detection+recognition reader for the requested backend.
    if backend == "paddle":
        reader = get_paddle_reader('vi')
    elif backend == "surya":
        text = ocr_with_surya(img_cell, langs=["vi", "en"])
        if text.strip():
            return text.strip()
        reader = get_paddle_reader('vi')
    else:
        reader = get_easyocr_reader()
    if reader is None:
        reader = get_easyocr_reader()
    if reader is None:
        return best_text

    # Variant 1: colour image with CLAHE on the lightness channel
    smoothed = cv2.bilateralFilter(img_cell, 5, 50, 50)
    l, a, b = cv2.split(cv2.cvtColor(smoothed, cv2.COLOR_BGR2LAB))
    l = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4)).apply(l)
    img_color = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)

    # Variant 3: global Otsu binarisation
    gray = cv2.cvtColor(img_cell, cv2.COLOR_BGR2GRAY)
    _, otsu = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    variants = [
        img_color,
        # Variant 2: handwriting preprocessing (removes rule lines)
        preprocess_for_handwriting(img_cell, min_width=300),
        cv2.cvtColor(otsu, cv2.COLOR_GRAY2BGR),
        # Variant 4: for faded or broken handwritten strokes
        enhance_faded_text(img_cell),
    ]
    for variant in variants:
        items, conf = ocr_single_pass(reader, variant)
        text = " ".join(it["text"] for it in items)
        if conf > best_conf and text.strip():
            best_conf = conf
            best_text = text

    # Also try English PaddleOCR for specs like "M6x50", "CT3"
    if backend == "paddle":
        reader_en = get_paddle_reader('en')
        if reader_en:
            items_en, conf_en = ocr_single_pass(reader_en, img_color)
            text_en = " ".join(it["text"] for it in items_en)
            # Only prefer English if it looks like specs/numbers
            if (conf_en > best_conf and text_en.strip()
                    and re.search(r'[A-Z0-9]', text_en)):
                best_conf = conf_en
                best_text = text_en
    return best_text
def ocr_table_grid(img, backend="paddle"):
    """OCR a table by detecting its grid first, then reading cell by cell.

    Rows with no recognised text are skipped. The remaining rows are padded
    or truncated to the most common column count, then domain-corrected:
    by column position when there are at least four columns, otherwise cell
    by cell in "auto" mode.

    Returns:
        ``{"rows": [...], "text": "..."}``, or ``None`` when no grid was
        detected or every row was empty.
    """
    detection = detect_table_cells_by_intersection(img)
    grid = detection[1] if isinstance(detection, tuple) else None
    if not grid:
        return None

    print(f" Grid detected: {len(grid)} rows")
    pad = 3  # inset so the rule lines themselves are not part of the crop
    img_h, img_w = img.shape[0], img.shape[1]

    all_rows = []
    for row_cells in grid:
        row_texts = []
        for x1, y1, x2, y2 in row_cells:
            top = max(0, y1 + pad)
            left = max(0, x1 + pad)
            bottom = min(img_h, y2 - pad)
            right = min(img_w, x2 - pad)
            if bottom <= top or right <= left:
                row_texts.append("")
                continue
            cell_text = ocr_cell_improved(img[top:bottom, left:right],
                                          backend=backend)
            row_texts.append(cell_text.strip())
        if any(row_texts):  # skip rows with no text at all
            all_rows.append(row_texts)

    if not all_rows:
        return None

    # The most common row length decides the table width.
    col_counts = [len(r) for r in all_rows]
    expected_cols = max(set(col_counts), key=col_counts.count)
    normalized_rows = [
        (row + [""] * (expected_cols - len(row)))[:expected_cols]
        for row in all_rows
    ]

    corrected_rows = []
    for row in normalized_rows:
        if expected_cols >= 4:
            corrected_rows.append(correct_table_row(row, num_columns=expected_cols))
        else:
            corrected_rows.append([correct_technical_text(cell) for cell in row])

    text = "\n".join(" | ".join(r) for r in corrected_rows)
    return {"rows": corrected_rows, "text": text}
def ocr_table(img_path, backend="paddle"):
    """OCR a table image, trying three strategies in turn.

    1. Grid-based cell detection followed by per-cell OCR.
    2. PPStructure table recognition (paddle backend only).
    3. Contour-based cell detection, whose rows are domain-corrected here.

    Args:
        img_path: Path of the image file to read.
        backend: OCR backend name forwarded to the helpers.

    Returns:
        A dict with "rows" and "text" (plus "html" for PPStructure results);
        empty rows and text when the image cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        return {"rows": [], "text": ""}

    # Strategy 1: grid-based cell detection + OCR
    print(f" Trying grid-based table OCR...")
    grid_result = ocr_table_grid(img, backend)
    if grid_result and grid_result.get("rows"):
        print(f" Grid OCR: {len(grid_result['rows'])} rows")
        return grid_result

    # Strategy 2: PPStructure (paddle backend only)
    pp_engine = get_pp_structure() if backend == "paddle" else None
    if pp_engine is not None:
        try:
            _, width = img.shape[:2]
            img_scaled = img
            if width < 1200:
                factor = 1200 / width
                img_scaled = cv2.resize(img, None, fx=factor, fy=factor,
                                        interpolation=cv2.INTER_CUBIC)
            for item in pp_engine(img_scaled):
                if item.get('type') != 'table':
                    continue
                html = item.get('res', {}).get('html', '')
                if not html:
                    continue
                rows = parse_html_table(html)
                if not rows:
                    continue
                corrected_rows = [
                    [correct_technical_text(cell) for cell in row]
                    for row in rows
                ]
                text = "\n".join(" | ".join(r) for r in corrected_rows)
                print(f" PPStructure: {len(corrected_rows)} rows")
                return {"rows": corrected_rows, "text": text, "html": html}
        except Exception as e:
            print(f" PPStructure error: {e}")

    # Strategy 3: contour-based cell detection
    print(f" Trying contour-based table OCR...")
    result = ocr_table_manual(img, img_path, backend)
    if result.get("rows"):
        corrected_rows = [
            [correct_technical_text(cell) for cell in row]
            for row in result["rows"]
        ]
        result["rows"] = corrected_rows
        result["text"] = "\n".join(" | ".join(r) for r in corrected_rows)
    return result
def ocr_table_manual(img, img_path, backend="paddle"):
    """OCR a table using contour-detected cells.

    Cells covering more than 90% of the image in both directions, or smaller
    than 15 px on either side, are skipped. Each remaining cell is cropped
    with 3 px of outward padding and recognised; the texts are grouped into
    rows. Falls back to whole-image OCR when no cell yields text.

    Args:
        img: BGR table image.
        img_path: Not used by this function.
        backend: OCR backend name.

    Returns:
        A dict with "rows" and "text".
    """
    pad = 3
    recognised = []
    cells = detect_table_structure(img)
    if cells:
        img_h, img_w = img.shape[:2]
        for x1, y1, x2, y2 in cells:
            cell_w, cell_h = x2 - x1, y2 - y1
            spans_whole_image = cell_w > img_w * 0.9 and cell_h > img_h * 0.9
            if spans_whole_image or cell_w < 15 or cell_h < 15:
                continue
            crop = img[max(0, y1 - pad):min(img.shape[0], y2 + pad),
                       max(0, x1 - pad):min(img.shape[1], x2 + pad)]
            text = ocr_cell_improved(crop, backend=backend)
            if not text:
                continue
            recognised.append({
                "text": text.strip(),
                "x": (x1 + x2) // 2,
                "y": (y1 + y2) // 2,
                "box": (x1, y1, x2, y2),
            })

    if not recognised:
        return ocr_table_fullimage(img, backend)

    rows = group_rows(recognised, vertical_thresh_ratio=0.5)
    return {
        "rows": rows,
        "text": "\n".join(" | ".join(r) for r in rows),
    }
# Cached PPStructure engine; stays None until the first successful init.
_pp_structure = None


def get_pp_structure():
    """Return the shared PPStructure engine, creating it on first use.

    Returns:
        The engine, or ``None`` when PPStructure cannot be imported or
        initialised (a warning is printed and the next call tries again).
    """
    global _pp_structure
    if _pp_structure is None:
        try:
            from paddleocr import PPStructure
            print("[INFO] Initializing PPStructure...")
            _pp_structure = PPStructure(
                table=True, ocr=True, lang='vi',
                show_log=False, use_gpu=(DEVICE == "cuda"),
            )
        except Exception as e:
            print(f"[WARN] PPStructure init failed: {e}")
            return None
    return _pp_structure
# Patterns for the lightweight HTML table parser below. Tags are matched
# case-insensitively and may carry attributes (e.g. <tr class="x">,
# <td colspan="2">); both <td> and <th> count as cells.
_TABLE_ROW_RE = re.compile(r'<tr\b[^>]*>(.*?)</tr\s*>', re.DOTALL | re.IGNORECASE)
_TABLE_CELL_RE = re.compile(r'<(t[dh])\b[^>]*>(.*?)</\1\s*>', re.DOTALL | re.IGNORECASE)
_HTML_TAG_RE = re.compile(r'<[^>]+>')


def parse_html_table(html_str):
    """Extract the cell texts of an HTML table as a list of rows.

    Each ``<tr>`` element becomes one row and each ``<td>``/``<th>`` inside
    it becomes one cell. Markup nested inside a cell is removed and the
    remaining text is stripped of surrounding whitespace. Empty cells are
    kept as ``""`` so column positions are preserved; rows without any cell
    are dropped.

    Args:
        html_str: HTML fragment containing the table. ``None`` or an empty
            string yields an empty result.

    Returns:
        A list of rows, each a list of cell strings.
    """
    if not html_str:
        return []

    rows = []
    for row_html in _TABLE_ROW_RE.findall(html_str):
        cells = [
            _HTML_TAG_RE.sub('', cell.group(2)).strip()
            for cell in _TABLE_CELL_RE.finditer(row_html)
        ]
        if cells:
            rows.append(cells)
    return rows
def ocr_table_fullimage(img, backend="paddle"):
    """OCR a whole table image at once and group the results into rows.

    With the "surya" backend every non-blank output line becomes a
    single-cell row and the raw Surya text is returned as ``"text"``.
    Otherwise a Paddle or EasyOCR reader is used: a first pass runs on the
    image prepared by ``preprocess_for_ocr`` in "table" mode, and a second
    pass on the ``preprocess_for_handwriting`` variant if the first finds
    nothing. Recognised texts go through ``correct_technical_text`` before
    being grouped into rows.

    Args:
        img: Image array of the table.
        backend: "surya", "paddle", or anything else (EasyOCR).

    Returns:
        A dict with ``"rows"`` (rows of strings) and ``"text"``.
    """
    if backend == "surya":
        raw_text = ocr_with_surya(img, langs=["vi", "en"])
        stripped_lines = (part.strip() for part in raw_text.split("\n"))
        single_cell_rows = [[ln] for ln in stripped_lines if ln]
        return {"rows": single_cell_rows, "text": raw_text}

    if backend == "paddle":
        reader = get_paddle_reader('vi')
    else:
        reader = get_easyocr_reader()
    if reader is None:
        # Fall back to EasyOCR when the chosen reader is unavailable.
        reader = get_easyocr_reader()

    table_view = preprocess_for_ocr(img, min_width=1500, mode="table")
    items, _ = ocr_single_pass(reader, table_view)
    if not items:
        # Nothing found: retry with the handwriting preprocessing.
        handwriting_view = preprocess_for_handwriting(img, min_width=1800)
        items, _ = ocr_single_pass(reader, handwriting_view)
        if not items:
            return {"rows": [], "text": ""}

    # Domain corrections are written back into the item dicts.
    for entry in items:
        entry["text"] = correct_technical_text(entry["text"])

    rows = group_rows(items, vertical_thresh_ratio=0.6)
    joined = "\n".join(" | ".join(r) for r in rows)
    return {"rows": rows, "text": joined}
# ============================================================
# GROUP ROWS
# ============================================================
def group_rows(items, vertical_thresh_ratio=0.6):
    """Group OCR items into rows by vertical position.

    Items are ordered by their ``"y"`` value. An item joins the current row
    when its ``y`` is less than a threshold away from the previous item's
    ``y``; otherwise it starts a new row. The threshold is
    ``max(8, median gap between consecutive y values * ratio)``. Each row is
    then ordered by ``"x"``.

    Args:
        items: Dicts with at least ``"text"``, ``"x"`` and ``"y"`` keys.
        vertical_thresh_ratio: Fraction of the median y-gap used as the
            row-break threshold.

    Returns:
        A list of rows, each a list of the items' ``"text"`` values, rows
        top to bottom and cells left to right. Empty input gives ``[]``.
    """
    if not items:
        return []

    ordered = sorted(items, key=lambda entry: entry["y"])
    centers = [entry["y"] for entry in ordered]

    if len(centers) <= 1:
        limit = 12
    else:
        spacing = [below - above for above, below in zip(centers, centers[1:])]
        limit = max(8, np.median(spacing) * vertical_thresh_ratio)

    bands = [[ordered[0]]]
    for entry in ordered[1:]:
        # Compare against the most recently placed item, not the row start.
        if entry["y"] - bands[-1][-1]["y"] < limit:
            bands[-1].append(entry)
        else:
            bands.append([entry])

    return [
        [entry["text"] for entry in sorted(band, key=lambda e: e["x"])]
        for band in bands
    ]
# ============================================================
# POST-PROCESSING
# ============================================================
def post_process_ocr_text(text):
    """Clean common OCR character confusions in one recognised string.

    Steps, in order:
      1. "O" between two digits becomes "0".
      2. "O" directly after "M" and before a digit becomes "0".
      3. "O" directly after "Ø" becomes "0".
      4. "l" or "I" between two digits becomes "1".
      5. "<digits> x <digits>" (either case of x, optional spaces) becomes
         "<digits>×<digits>".
      6. Whitespace runs collapse to one space and the ends are trimmed.
      7. The result goes through ``correct_technical_text``.

    Args:
        text: Raw OCR string. Falsy values (``""``, ``None``) are returned
            unchanged.

    Returns:
        The cleaned string.
    """
    if not text:
        return text
    text = re.sub(r'(?<=[0-9])O(?=[0-9])', '0', text)
    # Require a following digit: a bare "MO" -> "M0" rule also rewrites
    # ordinary words such as REMOVE, MOTOR or MOUNT.
    text = re.sub(r'(?<=M)O(?=[0-9])', '0', text)
    text = re.sub(r'(?<=Ø)O', '0', text)
    text = re.sub(r'(?<=[0-9])[lI](?=[0-9])', '1', text)
    text = re.sub(r'(\d+)\s*[xX]\s*(\d+)', r'\1×\2', text)
    text = re.sub(r'\s+', ' ', text).strip()
    # Domain correction
    text = correct_technical_text(text)
    return text
# ============================================================
# OCR NOTE
# ============================================================
def ocr_note(img_path, backend="paddle"):
    """OCR a note crop and return its cleaned text, one line per item.

    With the "surya" backend each non-blank output line is passed through
    ``post_process_ocr_text``. With "paddle", a Vietnamese and an English
    reader are each run through ``multi_pass_ocr`` and the result with the
    higher confidence is kept. If neither Paddle reader is available, or
    another backend is named, the EasyOCR reader is used instead.

    Args:
        img_path: Path of the image file to read.
        backend: "surya", "paddle", or anything else (EasyOCR).

    Returns:
        The processed lines joined with newlines; ``""`` when the image
        cannot be read or nothing is recognised.
    """
    img = cv2.imread(img_path)
    if img is None:
        return ""

    if backend == "surya":
        raw_text = ocr_with_surya(img, langs=["vi", "en"])
        stripped_lines = (part.strip() for part in raw_text.split("\n"))
        return "\n".join(
            [post_process_ocr_text(ln) for ln in stripped_lines if ln]
        )

    if backend == "paddle":
        reader_vi = get_paddle_reader('vi')
        reader_en = get_paddle_reader('en')
    else:
        reader_vi = reader_en = None
    if reader_vi is None and reader_en is None:
        reader_vi = get_easyocr_reader()

    # Keep the pass with the strictly highest confidence (ties keep the
    # earlier reader; a confidence of 0.0 never wins).
    best_items, best_conf = [], 0.0
    for reader in (reader_vi, reader_en):
        if not reader:
            continue
        items, conf = multi_pass_ocr(img, reader, "note")
        if conf > best_conf:
            best_items, best_conf = items, conf

    cleaned = [post_process_ocr_text(it["text"]) for it in best_items]
    return "\n".join(line for line in cleaned if line)
# ============================================================
# MAIN PIPELINE
# ============================================================
def run_pipeline(image_path, output_dir="outputs",
                 checkpoint="best.pt", conf_thresh=0.3,
                 ocr_backend="paddle"):
    """Detect regions in an image, OCR notes and tables, and save the results.

    For every detected box the padded crop is written to
    ``<output_dir>/<stem>/crops/<Class>_<n>.jpg``. Crops of class "note" go
    through ``ocr_note`` and crops of class "table" through ``ocr_table``;
    other classes get no OCR content. An annotated copy of the image is
    saved as ``result_vis.jpg`` and the structured result as
    ``result.json`` in ``<output_dir>/<stem>/``.

    Args:
        image_path: Path of the input image (str or path-like).
        output_dir: Root directory for all outputs.
        checkpoint: Detector checkpoint passed to ``get_det_model``.
        conf_thresh: Detection confidence threshold.
        ocr_backend: OCR backend name forwarded to the OCR functions.

    Returns:
        ``(result, vis_path)`` where ``result`` is
        ``{"image": <file name>, "objects": [...]}`` and ``vis_path`` is the
        path of the annotated image.

    Raises:
        ValueError: If the image cannot be read.
    """
    image_path = str(image_path)
    img_name = Path(image_path).name
    stem = Path(image_path).stem

    # Validate the input before creating directories or running detection.
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError(f"Cannot read: {image_path}")

    out_dir = Path(output_dir) / stem
    crop_dir = out_dir / "crops"
    crop_dir.mkdir(parents=True, exist_ok=True)

    model = get_det_model(checkpoint)
    results = model(image_path, imgsz=1024, conf=conf_thresh,
                    iou=0.5, device=DEVICE, verbose=False)

    # Draw on a separate copy: crops must come from the untouched image,
    # otherwise outlines and labels of earlier detections end up inside
    # later crops and get OCR'd as content.
    vis_bgr = img_bgr.copy()
    img_h, img_w = img_bgr.shape[:2]
    pad = 10

    objects = []
    for i, box in enumerate(results[0].boxes):
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        cls_idx = int(box.cls[0])
        conf_val = round(float(box.conf[0]), 4)
        cls_raw = CLASS_NAMES[cls_idx]
        cls_show = CLASS_DISPLAY[cls_raw]

        crop = img_bgr[max(0, y1 - pad):min(img_h, y2 + pad),
                       max(0, x1 - pad):min(img_w, x2 + pad)]
        crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 98])

        ocr_content = None
        if cls_raw == "note":
            print(f"[OCR] Note #{i+1} ({x2-x1}x{y2-y1}px)...")
            ocr_content = ocr_note(crop_path, backend=ocr_backend)
            print(f" → {repr(ocr_content[:120]) if ocr_content else 'EMPTY'}")
        elif cls_raw == "table":
            print(f"[OCR] Table #{i+1} ({x2-x1}x{y2-y1}px)...")
            ocr_content = ocr_table(crop_path, backend=ocr_backend)
            preview = ocr_content.get("text", "")[:120]
            print(f" → {repr(preview) if preview else 'EMPTY'}")

        objects.append({
            "id": i+1, "class": cls_show,
            "confidence": conf_val,
            "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            "crop_path": crop_path,
            "ocr_content": ocr_content,
        })

        # Box outline plus a filled label strip above its top-left corner.
        color = COLORS[cls_raw]
        cv2.rectangle(vis_bgr, (x1, y1), (x2, y2), color, 2)
        label = f"{cls_show} {conf_val:.2f}"
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        cv2.rectangle(vis_bgr, (x1, y1-th-10), (x1+tw+8, y1), color, -1)
        cv2.putText(vis_bgr, label, (x1+4, y1-4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

    vis_path = str(out_dir / "result_vis.jpg")
    cv2.imwrite(vis_path, vis_bgr)

    result = {"image": img_name, "objects": objects}
    json_path = str(out_dir / "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"[✓] {len(objects)} objects | {vis_path} | {json_path}")
    return result, vis_path
if __name__ == "__main__":
    import sys

    # CLI usage: inference.py [image_path] [ocr_backend]
    # Defaults: image "test.jpg", backend "easyocr".
    cli_args = sys.argv[1:]
    img = cli_args[0] if cli_args else "test.jpg"
    backend = cli_args[1] if len(cli_args) > 1 else "easyocr"

    result, _ = run_pipeline(img, ocr_backend=backend)
    print(json.dumps(result, ensure_ascii=False, indent=2))