# src/inference.py import torch _orig_torch_load = torch.load def _patched_load(*args, **kwargs): kwargs.setdefault("weights_only", False) return _orig_torch_load(*args, **kwargs) torch.load = _patched_load import cv2 import json import numpy as np from pathlib import Path from ultralytics import RTDETR import re from difflib import SequenceMatcher DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"[INFO] Device: {DEVICE}") CLASS_NAMES = ["note", "part-drawing", "table"] CLASS_DISPLAY = {"note": "Note", "part-drawing": "PartDrawing", "table": "Table"} COLORS = {"note": (0,165,255), "part-drawing": (0,200,0), "table": (0,0,220)} _det_model = None _ocr_paddle = None _ocr_paddle_en = None _ocr_easyocr = None _ocr_vietocr = None REALESRGAN_AVAILABLE = False _esrgan_upsampler = None # Thêm biến global try: from realesrgan import RealESRGANer from basicsr.archs.rrdbnet_arch import RRDBNet REALESRGAN_AVAILABLE = True print("[INFO] Real-ESRGAN is available") except ImportError: print("[WARN] Real-ESRGAN not installed. 
def get_esrgan_upsampler():
    """Return the shared Real-ESRGAN upsampler, loading it on first use.

    Returns:
        The cached ``RealESRGANer`` instance, or ``None`` when the optional
        ``realesrgan`` package is not installed or the model fails to load
        (a failed load is retried on the next call).
    """
    global _esrgan_upsampler
    if not REALESRGAN_AVAILABLE:
        return None
    if _esrgan_upsampler is None:
        try:
            print("[INFO] Loading Real-ESRGAN model...")
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=6, num_grow_ch=32, scale=4)
            _esrgan_upsampler = RealESRGANer(
                scale=4,
                model_path='weights/RealESRGAN_x4plus_anime_6B.pth',
                model=model,
                device=DEVICE
            )
        except Exception as e:
            print(f"[WARN] Failed to load Real-ESRGAN: {e}")
            return None
    return _esrgan_upsampler


def upscale_if_needed(img_bgr, min_dim=300):
    """Upscale 2x with Real-ESRGAN when either image side is below ``min_dim``.

    Args:
        img_bgr: image array; only ``shape[:2]`` (height, width) is inspected.
        min_dim: minimum side length in pixels below which upscaling is tried.

    Returns:
        The enhanced image, or ``img_bgr`` unchanged when it is large enough,
        Real-ESRGAN is unavailable, or enhancement fails.
    """
    h, w = img_bgr.shape[:2]
    if h >= min_dim and w >= min_dim:
        return img_bgr
    upsampler = get_esrgan_upsampler()
    if upsampler is None:
        return img_bgr
    try:
        output, _ = upsampler.enhance(img_bgr, outscale=2)
        return output
    except Exception as e:
        print(f"[WARN] ESRGAN upscale failed: {e}")
    return img_bgr


# ============================================================
# DOMAIN DICTIONARY - vocabulary of Vietnamese technical drawings
# ============================================================
# Canonical spelling -> lower-cased OCR variants seen for it (the canonical
# spelling itself, the accent-stripped form and common misreads).
_OCR_VARIANTS = {
    # Part names
    "Bọc táp": ("bọc táp", "boc tap", "bạc táp", "bọc tốp", "bọc tot"),
    "Vòng đệm": ("vòng đệm", "vong dem", "vòng dệm", "vong đệm"),
    "Chốt trụ": ("chốt trụ", "chot tru", "chốt trự", "chot trụ"),
    "Vít": ("vít", "vit"),
    "Bu lông": ("bu lông", "bu long", "bulong", "bu-lông", "bulông"),
    "Vòng đệm vênh": ("vòng đệm vênh", "vong dem venh", "vòng dệm vênh"),
    "Then bằng": ("then bằng", "then bang", "then bảng"),
    "Ống dẫn": ("ống dẫn", "ong dan", "ống chẫn", "ống dần", "ông dẫn",
                "ông chẫn", "ống chến", "ông chến"),
    "Chốt chặn": ("chốt chặn", "chot chan", "chốt chắn", "cnốt chến", "chốt chén"),
    "Bạc lót": ("bạc lót", "bac lot", "bọc lót", "bạc lốt", "bọc lết"),
    "Giá đỡ": ("giá đỡ", "gia do", "giá dở", "giá đở"),
    "Bánh răng": ("bánh răng", "banh rang", "bành răng", "bánh rằng",
                  "bảnh răng", "bdnh răng", "bdình răng"),
    "Hộp bánh răng": ("hộp bánh răng", "hop banh rang", "hộp bành răng",
                      "mộp bành răng", "mộp bánh răng"),
    "Nắp": ("nắp", "nap", "năp", "nốp"),
    # Materials
    "Đồng nhôm": ("đồng nhôm", "dong nhom", "đồng thanh", "đồng thann",
                  "đổng nhôm", "đống nhôm", "đống thanh"),
    "Thép CT3": ("thép ct3", "thep ct3", "thếp ct3", "tnép ct3", "thếp cts",
                 "tnếp ct3", "thếp ctj", "tnép ctj", "tnếp ctj", "thep ctj"),
    "Thép 65": ("thép 65", "thep 65", "thếp 65", "tnếp 65", "thếp 65f",
                "tnép 65f", "thép 6sf", "thep 6sf", "thếp 6sf", "tnép 6sf"),
    "Thép 45": ("thép 45", "thep 45", "thếp 45", "tnếp 45"),
    "Sắt tây": ("sắt tây", "sat tay", "sắt tay", "sdi tay", "sdi day", "sdi đay"),
    "Gang 15-32": ("gang 15-32", "gang15-32", "gong 15-32", "gong15-32",
                   "gang 15.32", "gang 15 32", "gong 15.32", "gang1532"),
    # Table headers
    "Vị trí": ("vị trí", "vi tri", "v.trí"),
    "Tên chi tiết": ("tên chi tiết", "ten chi tiet"),
    "Tên chi tiết máy": ("tên chi tiết máy", "ten chi tiet may"),
    "Số lg": ("số lg", "so lg", "số lượng", "so luong", "s.lg", "số lý"),
    "Vật liệu": ("vật liệu", "vat lieu", "vat liéu"),
    "Ghi chú": ("ghi chú", "ghi chu"),
    # Title block
    "Bản vẽ số": ("bản vẽ số", "ban ve so", "bản gối"),
    "BƠM BÁNH RĂNG": ("bơm bánh răng", "bom banh rang", "bớm bánh răng"),
    "Bản vẽ lắp số": ("bản vẽ lắp số", "ban ve lap so", "bản vể lắp số",
                      "bán vẽ lắp số", "bán vể lắp số"),
    "Tỷ lệ": ("tỷ lệ", "ty le", "tý lệ"),
    "Bộ môn Hình hoạ": ("bộ môn hình hoạ", "bộ môn hình họa",
                        "bo mon hinh hoa", "bộ mốn hình hoạ"),
    "Đại học Bách khoa Hà Nội": ("đại học bách khoa hà nội",
                                 "dai hoc bach khoa ha noi",
                                 "đại học bách khoa",
                                 "bại hoc bách khoa",
                                 "bại học bách khoa hà nội"),
}

# Flat lookup used by the correction functions: lower-cased variant -> canonical.
TECH_DICTIONARY = {
    variant: canonical
    for canonical, variants in _OCR_VARIANTS.items()
    for variant in variants
}

# Canonical part names for fuzzy matching
CANONICAL_PARTS = [
    "Bọc táp", "Vòng đệm", "Chốt trụ", "Vít", "Bu lông",
    "Vòng đệm vênh", "Then bằng", "Ống dẫn", "Chốt chặn",
    "Bạc lót", "Giá đỡ", "Bánh răng", "Hộp bánh răng", "Nắp",
]

CANONICAL_MATERIALS = [
    "Đồng nhôm", "Thép CT3", "Thép 65", "Thép 45", "Sắt tây", "Gang 15-32",
]

CANONICAL_HEADERS = [
    "Vị trí", "Tên chi tiết", "Tên chi tiết máy", "Số lg", "Vật liệu", "Ghi chú",
]

# Column type applied to each leading column by correct_table_row.
_ROW_COLUMN_TYPES = ("position", "name", "quantity", "material", "auto")


def fuzzy_match(text, candidates, threshold=0.55):
    """Return the candidate closest to ``text``, or ``text`` if none is close.

    An exact (case-insensitive) hit in ``TECH_DICTIONARY`` wins before any
    fuzzy comparison. Otherwise candidates are scored with
    ``SequenceMatcher.ratio`` and the first best-scoring one is returned when
    its score reaches ``threshold``.

    Args:
        text: OCR text to match.
        candidates: iterable of canonical strings.
        threshold: minimum similarity ratio (0..1) to accept a match.

    Returns:
        The dictionary value, the best candidate, or ``text`` unchanged.
    """
    if not text or not candidates:
        return text

    text_lower = text.lower().strip()
    if text_lower in TECH_DICTIONARY:
        return TECH_DICTIONARY[text_lower]

    best_match = None
    best_score = 0
    for candidate in candidates:
        score = SequenceMatcher(None, text_lower, candidate.lower()).ratio()
        if score > best_score:
            best_score = score
            best_match = candidate

    # best_match stays None when nothing overlaps at all; never return None.
    if best_match is not None and best_score >= threshold:
        return best_match
    return text


def correct_technical_text(text, column_type="auto"):
    """Fix OCR errors in one text cell using domain knowledge.

    Args:
        text: raw OCR text.
        column_type: one of "position", "name", "quantity", "material",
            "note" or "auto". "note" only receives the general clean-ups.

    Returns:
        The corrected text; empty or whitespace-only input is returned as is.
    """
    if not text or not text.strip():
        return text

    original = text.strip()
    text_lower = original.lower()

    # 1. Exact dictionary lookup
    if text_lower in TECH_DICTIONARY:
        return TECH_DICTIONARY[text_lower]

    # 2. Column-specific corrections
    if column_type == "position":
        # Position column is an integer index: keep digits only.
        return re.sub(r'[^0-9]', '', original) or original

    if column_type == "auto" and original.replace('.', '').replace(',', '').isdigit():
        # Numeric text of unknown column: drop stray leading/trailing
        # punctuation ("12." -> "12") but keep interior separators so that a
        # decimal such as "1.5" is not collapsed into "15".
        return original.strip('.,')

    if column_type == "quantity" or (
        column_type == "auto"
        and len(original) <= 2
        and any(c.isdigit() for c in original)
    ):
        return re.sub(r'[^0-9]', '', original) or original

    if column_type == "name":
        # Part names first, then table headers.
        candidate_groups = [(CANONICAL_PARTS, 0.5), (CANONICAL_HEADERS, 0.5)]
    elif column_type == "material":
        candidate_groups = [(CANONICAL_MATERIALS, 0.5)]
    elif column_type == "auto":
        candidate_groups = [
            (CANONICAL_PARTS, 0.55),
            (CANONICAL_MATERIALS, 0.55),
            (CANONICAL_HEADERS, 0.55),
        ]
    else:
        candidate_groups = []

    for candidates, threshold in candidate_groups:
        result = fuzzy_match(original, candidates, threshold=threshold)
        if result != original:
            return result

    # 3. General corrections
    text_out = original
    # Thread specs: "m 6" -> "M6". The word boundary keeps words that merely
    # end in "m" intact ("2mm 3" must not become "2mM3").
    text_out = re.sub(r'\b[Mm]\s*(\d)', r'M\1', text_out)
    # Dimensions such as "5 x 8 - 35" -> "5x8-35"
    text_out = re.sub(r'(\d+)\s*[xX×]\s*(\d+)\s*[-–]\s*(\d+)', r'\1x\2-\3', text_out)
    text_out = re.sub(r'(\d+)\s*[xX×]\s*(\d+)\s*[xX×]\s*(\d+)', r'\1x\2x\3', text_out)
    # Handwritten "x" is often read as "n": "3n8" -> "3x8"
    text_out = re.sub(r'(\d+)\s*n\s*(\d+)', r'\1x\2', text_out)
    # "4*6*14" -> "4x6x14"
    text_out = re.sub(r'(\d+)\s*[*]\s*(\d+)\s*[*]\s*(\d+)', r'\1x\2x\3', text_out)

    return text_out


def correct_table_row(row, num_columns=5):
    """Correct a whole table row using the known column layout.

    Columns: [position, part name, quantity, material, note]. The row is
    padded with "" or truncated to ``num_columns``; only the columns that
    exist are corrected, so ``num_columns`` below 5 is safe, and columns
    beyond the fifth are passed through unchanged.

    Args:
        row: sequence of cell strings.
        num_columns: number of columns in the returned row.

    Returns:
        A new list of ``num_columns`` corrected cells, or ``row`` itself when
        it is empty.
    """
    if not row:
        return row

    width = max(num_columns, 0)
    corrected = list(row)[:width]
    corrected += [""] * (width - len(corrected))

    for idx, column_type in enumerate(_ROW_COLUMN_TYPES[:len(corrected)]):
        if corrected[idx]:
            corrected[idx] = correct_technical_text(corrected[idx], column_type)

    return corrected


# ============================================================
# MODEL LOADERS
# ============================================================
def get_det_model(checkpoint="best.pt"):
    """Return the shared RT-DETR detection model, loading it on first use.

    Note: the model is cached after the first call, so ``checkpoint`` is
    ignored on every later call.
    """
    global _det_model
    if _det_model is None:
        print(f"[INFO] Loading detection model: {checkpoint}")
        _det_model = RTDETR(checkpoint)
    return _det_model
# ============================================================
# SURYA OCR (optional)
# ============================================================
SURYA_AVAILABLE = False
try:
    from surya.ocr import run_ocr
    from surya.model.detection.model import load_det_processor, load_det_model
    from surya.model.recognition.model import load_rec_model
    from surya.model.recognition.processor import load_rec_processor
    SURYA_AVAILABLE = True
    print("[INFO] Surya OCR is available")
except ImportError:
    print("[WARN] Surya OCR not installed. Install with: pip install surya-ocr")

# Lazily loaded (det_processor, det_model, rec_model, rec_processor) tuple.
_surya_models = None


def _get_surya_models():
    """Load the four Surya models once and reuse them on later calls."""
    global _surya_models
    if _surya_models is None:
        _surya_models = (
            load_det_processor(),
            load_det_model(),
            load_rec_model(),
            load_rec_processor(),
        )
    return _surya_models


def ocr_with_surya(img_bgr, langs=("vi", "en")):
    """Run Surya OCR on a BGR image and return the text lines joined by "\\n".

    Args:
        img_bgr: image in OpenCV BGR channel order.
        langs: language codes passed to Surya (any iterable of strings).

    Raises:
        ImportError: if the optional ``surya`` package is not installed.
    """
    if not SURYA_AVAILABLE:
        raise ImportError("Surya OCR is not installed.")
    # Models are cached: this function is called once per table cell, and
    # reloading the weights on each call would dominate the run time.
    det_processor, det_model, rec_model, rec_processor = _get_surya_models()
    from PIL import Image
    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    predictions = run_ocr([pil_img], [list(langs)], det_model, det_processor,
                          rec_model, rec_processor)
    return "\n".join(line.text for line in predictions[0].text_lines)


# ============================================================
# VietOCR (optional - good for Vietnamese handwriting)
# ============================================================
VIETOCR_AVAILABLE = False
try:
    from vietocr.tool.predictor import Predictor
    from vietocr.tool.config import Cfg
    VIETOCR_AVAILABLE = True
    print("[INFO] VietOCR is available")
except ImportError:
    print("[WARN] VietOCR not installed. Install with: pip install vietocr")


def get_vietocr():
    """Return the shared VietOCR predictor, or ``None`` if unavailable.

    The predictor is created on first use; a failed load is reported and
    retried on the next call.
    """
    global _ocr_vietocr
    if _ocr_vietocr is None and VIETOCR_AVAILABLE:
        try:
            config = Cfg.load_config_from_name('vgg_transformer')
            config['cnn']['pretrained'] = True
            config['device'] = DEVICE
            _ocr_vietocr = Predictor(config)
            print("[INFO] VietOCR loaded successfully")
        except Exception as e:
            print(f"[WARN] VietOCR load failed: {e}")
    return _ocr_vietocr


def ocr_line_vietocr(img_bgr):
    """OCR a single text-line image with VietOCR; "" when it is unavailable."""
    predictor = get_vietocr()
    if predictor is None:
        return ""
    from PIL import Image
    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    return predictor.predict(pil_img).strip()


# ============================================================
# PaddleOCR / EasyOCR
# ============================================================
def get_paddle_reader(lang='vi'):
    """Return a cached PaddleOCR reader for ``lang``, creating it on first use.

    ``lang == 'en'`` uses the English cache slot; every other value shares
    the default slot.

    Returns:
        The reader, or ``None`` when PaddleOCR cannot be initialised.
    """
    global _ocr_paddle, _ocr_paddle_en
    cached = _ocr_paddle_en if lang == 'en' else _ocr_paddle
    if cached is not None:
        return cached

    try:
        from paddleocr import PaddleOCR
        print(f"[INFO] Initializing PaddleOCR PP-OCRv4 (lang={lang})...")
        reader = PaddleOCR(
            lang=lang,
            use_angle_cls=True,
            use_gpu=(DEVICE == "cuda"),
            show_log=False,
            ocr_version='PP-OCRv4',
            det_db_thresh=0.15,
            det_db_box_thresh=0.2,
            det_db_unclip_ratio=2.0,
            use_dilation=True,
            det_db_score_mode='slow',
            rec_image_shape="3,48,320",
            max_text_length=80,
            rec_batch_num=6,
        )
    except Exception as e:
        print(f"[WARN] PaddleOCR init failed: {e}")
        return None

    if lang == 'en':
        _ocr_paddle_en = reader
    else:
        _ocr_paddle = reader
    return reader


def get_easyocr_reader():
    """Return the shared EasyOCR reader (Vietnamese + English)."""
    global _ocr_easyocr
    if _ocr_easyocr is None:
        import easyocr
        _ocr_easyocr = easyocr.Reader(
            ["vi", "en"], gpu=(DEVICE == "cuda"), verbose=False
        )
    return _ocr_easyocr


# ============================================================
# PREPROCESSING
# ============================================================
def enhance_faded_text(img_bgr):
    """Sharpen and binarise an image whose strokes are faint.

    Applies unsharp masking to the grayscale image, then a local threshold:
    Sauvola from scikit-image when installed, otherwise OpenCV's Gaussian
    adaptive threshold.

    Returns:
        The binary result converted back to a 3-channel BGR image.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

    # 1. Unsharp masking (strengthens stroke edges)
    gaussian = cv2.GaussianBlur(gray, (0, 0), 2.0)
    unsharp = cv2.addWeighted(gray, 1.5, gaussian, -0.5, 0)

    # 2. Local thresholding
    try:
        from skimage.filters import threshold_sauvola
    except ImportError:
        # Fall back to OpenCV when scikit-image is not installed.
        binary = cv2.adaptiveThreshold(unsharp, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY, 21, 10)
    else:
        window_size = 25
        thresh = threshold_sauvola(unsharp, window_size=window_size)
        binary = ((unsharp > thresh) * 255).astype(np.uint8)

    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
def _upscale_to_min_width(img_bgr, min_width):
    """Enlarge ``img_bgr`` (cubic, aspect preserved) so its width >= ``min_width``."""
    w = img_bgr.shape[1]
    if w < min_width:
        scale = min_width / w
        img_bgr = cv2.resize(img_bgr, None, fx=scale, fy=scale,
                             interpolation=cv2.INTER_CUBIC)
    return img_bgr


def _clahe_on_lightness(img_bgr, clip_limit, tile_grid_size):
    """Apply CLAHE to the L channel of the LAB representation of a BGR image."""
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    lab = cv2.merge([clahe.apply(l), a, b])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)


def preprocess_for_ocr(img_bgr, min_width=1500, mode="note"):
    """Colour preprocessing for OCR: upscale, denoise and boost contrast.

    Args:
        img_bgr: BGR image.
        min_width: images narrower than this are upscaled to this width.
        mode: "note" adds a sharpening filter after CLAHE; any other value
            (table) uses stronger denoising and CLAHE without sharpening.

    Returns:
        The processed BGR image.
    """
    img_bgr = _upscale_to_min_width(img_bgr, min_width)

    if mode == "note":
        img_proc = cv2.bilateralFilter(img_bgr, 9, 75, 75)
        img_proc = _clahe_on_lightness(img_proc, 2.0, (8, 8))
        kernel = np.array([[0, -0.5, 0], [-0.5, 3, -0.5], [0, -0.5, 0]])
        return cv2.filter2D(img_proc, -1, kernel)

    # table
    img_proc = cv2.bilateralFilter(img_bgr, 11, 80, 80)
    return _clahe_on_lightness(img_proc, 3.0, (4, 4))


def preprocess_for_handwriting(img_bgr, min_width=1800):
    """Preprocessing tuned for handwriting.

    Removes long horizontal and vertical rules (table borders), boosts
    contrast with CLAHE and binarises with an adaptive threshold.

    Returns:
        The binary result as a 3-channel BGR image.
    """
    img_bgr = _upscale_to_min_width(img_bgr, min_width)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    h_img, w_img = gray.shape

    # Detect long horizontal strokes on the inverted image and paint them white.
    h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(40, w_img // 10), 1))
    h_lines = cv2.morphologyEx(~gray, cv2.MORPH_OPEN, h_kernel, iterations=1)
    gray_no_lines = gray.copy()
    gray_no_lines[h_lines > 128] = 255

    # Same for vertical strokes.
    v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(40, h_img // 10)))
    v_lines = cv2.morphologyEx(~gray, cv2.MORPH_OPEN, v_kernel, iterations=1)
    gray_no_lines[v_lines > 128] = 255

    # CLAHE for better contrast
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray_no_lines)

    # Adaptive threshold copes with uneven stroke darkness.
    binary = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY, 21, 10)

    # Light morphological cleaning
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)

    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)


def preprocess_grayscale_variant(img_bgr, min_width=1500):
    """Upscale, apply CLAHE on grayscale and binarise with Otsu's threshold."""
    img_bgr = _upscale_to_min_width(img_bgr, min_width)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray = clahe.apply(gray)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)


# ============================================================
# OCR FUNCTIONS
# ============================================================
def _make_ocr_item(box, text, conf):
    """Build the item dict shared by both OCR backends from one detection."""
    xs = [p[0] for p in box]
    ys = [p[1] for p in box]
    return {
        "text": text.strip(), "conf": conf,
        "x": np.mean(xs), "y": np.mean(ys),
        "x1": min(xs), "y1": min(ys),
        "x2": max(xs), "y2": max(ys),
        "box": box
    }


def ocr_single_pass(reader, img_bgr):
    """Run OCR once and return ``(items, avg_confidence)``.

    A reader exposing ``.ocr`` is treated as PaddleOCR (detections below 0.15
    confidence are dropped); anything else is treated as EasyOCR and called
    through ``.readtext`` (cut-off 0.1). Blank texts are dropped.

    Returns:
        A list of dicts with keys text, conf, x, y, x1, y1, x2, y2, box, and
        the mean confidence of the kept items (0.0 when none are kept).
    """
    if hasattr(reader, 'ocr'):
        # PaddleOCR: result[0] holds [box, (text, conf)] entries.
        result = reader.ocr(img_bgr, cls=True)
        if not result or not result[0]:
            return [], 0.0
        detections = [(box, text, conf) for box, (text, conf) in result[0]]
        min_conf = 0.15
    else:
        # EasyOCR: (points, text, conf) triples.
        detections = reader.readtext(img_bgr, detail=1, paragraph=False)
        min_conf = 0.1

    items = [
        _make_ocr_item(box, text, conf)
        for box, text, conf in detections
        if conf >= min_conf and text.strip()
    ]
    avg_conf = np.mean([it["conf"] for it in items]) if items else 0.0
    return items, avg_conf


def multi_pass_ocr(img_bgr, reader, ocr_type="note"):
    """Run OCR over five preprocessing variants and keep the best one.

    The image is first passed through ``upscale_if_needed`` (small crops).
    "Best" is the variant with the highest average confidence; the first
    one wins ties.

    Returns:
        ``(best_items, best_conf)``; ``([], 0.0)`` when every pass is empty.
    """
    img_bgr = upscale_if_needed(img_bgr, min_dim=400)

    passes = (
        # Pass 1: colour preprocessing
        lambda img: preprocess_for_ocr(img, min_width=1500, mode=ocr_type),
        # Pass 2: handwriting-optimised preprocessing
        lambda img: preprocess_for_handwriting(img, min_width=1800),
        # Pass 3: extra upscale
        lambda img: preprocess_for_ocr(img, min_width=2500, mode=ocr_type),
        # Pass 4: grayscale Otsu
        lambda img: preprocess_grayscale_variant(img, min_width=1500),
        # Pass 5: faded / faint text
        enhance_faded_text,
    )

    best_items = []
    best_conf = 0.0
    confidences = []
    for preprocess in passes:
        items, conf = ocr_single_pass(reader, preprocess(img_bgr))
        confidences.append(conf)
        if conf > best_conf:
            best_conf = conf
            best_items = items

    summary = ", ".join(f"{c:.3f}" for c in confidences)
    print(f" Multi-pass confidences: {summary} → best={best_conf:.3f}")
    return best_items, best_conf
# ============================================================
# TABLE STRUCTURE — Intersection-based cell detection
# ============================================================
def detect_lines(gray, direction="horizontal", min_length_ratio=0.15):
    """Extract long straight rules from a grayscale image.

    Args:
        gray: single-channel image (dark ink on light background is assumed
            by the inverse threshold at 150).
        direction: "horizontal", anything else means vertical.
        min_length_ratio: minimum line length as a fraction of the image
            width (horizontal) or height (vertical); never below 30 px.

    Returns:
        A binary mask in which detected lines are non-zero.
    """
    h, w = gray.shape
    _, binary = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)

    if direction == "horizontal":
        kernel_size = (max(30, int(w * min_length_ratio)), 1)
    else:
        kernel_size = (1, max(30, int(h * min_length_ratio)))
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)

    lines_img = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)

    # Dilate slightly to reconnect broken lines.
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    return cv2.dilate(lines_img, dilate_kernel, iterations=1)


def find_line_positions(lines_img, direction="horizontal", merge_distance=10):
    """Find line coordinates in a line mask.

    Rows (horizontal) or columns (vertical) whose summed intensity exceeds
    30% of the strongest one are taken as line pixels; neighbours within
    ``merge_distance`` of the running position are merged by averaging.

    Returns:
        A list of plain ``int`` y-coordinates (horizontal) or x-coordinates
        (vertical); ``[]`` for an empty or line-free mask.
    """
    axis = 1 if direction == "horizontal" else 0
    projection = np.sum(lines_img, axis=axis)
    if projection.size == 0:
        # np.max would raise on an empty projection.
        return []

    threshold = np.max(projection) * 0.3
    # tolist() yields Python ints, which are safe for slicing and JSON.
    positions = np.where(projection > threshold)[0].tolist()
    if not positions:
        return []

    merged = [positions[0]]
    for pos in positions[1:]:
        if pos - merged[-1] > merge_distance:
            merged.append(pos)
        else:
            merged[-1] = (merged[-1] + pos) // 2
    return merged


def detect_table_cells_by_intersection(img_bgr):
    """Detect table cells from the grid of horizontal and vertical rules.

    Returns:
        A ``(cells, grid)`` tuple. ``cells`` is a flat list of
        ``(x1, y1, x2, y2)`` boxes and ``grid`` the same boxes grouped per
        row. When fewer than two lines are found in either direction the
        contour-based ``detect_table_structure`` result is returned as
        ``cells`` and ``grid`` is ``None``.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    h, w = gray.shape

    h_lines = detect_lines(gray, "horizontal", min_length_ratio=0.1)
    v_lines = detect_lines(gray, "vertical", min_length_ratio=0.1)

    y_positions = find_line_positions(h_lines, "horizontal",
                                      merge_distance=max(8, h // 50))
    x_positions = find_line_positions(v_lines, "vertical",
                                      merge_distance=max(8, w // 50))

    print(f" Table grid: {len(y_positions)} horizontal × {len(x_positions)} vertical lines")

    if len(y_positions) < 2 or len(x_positions) < 2:
        # Fallback to contour-based detection
        return detect_table_structure(img_bgr), None

    cells = []
    grid = []
    for y1, y2 in zip(y_positions, y_positions[1:]):
        # Cells narrower or shorter than 10 px are treated as noise.
        row_cells = [
            (x1, y1, x2, y2)
            for x1, x2 in zip(x_positions, x_positions[1:])
            if (x2 - x1) >= 10 and (y2 - y1) >= 10
        ]
        cells.extend(row_cells)
        if row_cells:
            grid.append(row_cells)

    return cells, grid


def detect_table_structure(img_bgr):
    """Fallback contour-based cell detection.

    Returns:
        Unique ``(x1, y1, x2, y2)`` boxes sorted top-to-bottom then
        left-to-right, keeping contours between 0.1% and 85% of the image
        area that are more than 15 px in both dimensions.
    """
    h, w = img_bgr.shape[:2]
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)

    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(40, w // 15), 1))
    horizontal_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel,
                                        iterations=2)
    vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(40, h // 15)))
    vertical_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel,
                                      iterations=2)

    table_structure = cv2.add(horizontal_lines, vertical_lines)
    contours, _ = cv2.findContours(table_structure, cv2.RETR_TREE,
                                   cv2.CHAIN_APPROX_SIMPLE)

    min_cell_area = (w * h) * 0.001
    max_cell_area = (w * h) * 0.85
    cells = set()
    for cnt in contours:
        x, y, cw, ch = cv2.boundingRect(cnt)
        if min_cell_area < cw * ch < max_cell_area and cw > 15 and ch > 15:
            cells.add((x, y, x + cw, y + ch))

    return sorted(cells, key=lambda r: (r[1], r[0]))
# ============================================================
# OCR TABLE — Grid-based approach
# ============================================================
def _rows_to_text(rows):
    """Render rows as lines of " | "-separated cells."""
    return "\n".join(" | ".join(r) for r in rows)


def _auto_correct_rows(rows):
    """Apply column-agnostic domain correction to every cell."""
    return [[correct_technical_text(cell) for cell in row] for row in rows]


def ocr_cell_improved(img_cell, backend="paddle"):
    """OCR one table cell, trying several engines and preprocessings.

    VietOCR (when available) provides a baseline with a fixed confidence of
    0.7; each PaddleOCR/EasyOCR variant replaces it only with a strictly
    higher average confidence and non-blank text. With the "surya" backend a
    non-empty Surya result is returned directly.

    Args:
        img_cell: BGR crop of the cell, or ``None``.
        backend: "paddle", "surya", anything else selects EasyOCR.

    Returns:
        The best text found, "" for a missing, empty or tiny (< 5 px) crop.
    """
    if img_cell is None or img_cell.size == 0:
        return ""

    # Upscale very small cells with ESRGAN
    img_cell = upscale_if_needed(img_cell, min_dim=150)

    h, w = img_cell.shape[:2]
    if h < 5 or w < 5:
        return ""

    # Bring the crop to at least 64 px high and 200 px wide with a single
    # resize, so the two minimums do not compound into a double enlargement.
    scale = max(64 / h, 200 / w)
    if scale > 1:
        img_cell = cv2.resize(img_cell, None, fx=scale, fy=scale,
                              interpolation=cv2.INTER_CUBIC)

    best_text = ""
    best_conf = 0

    # Try VietOCR first (better for handwriting)
    if VIETOCR_AVAILABLE:
        try:
            vietocr_text = ocr_line_vietocr(img_cell)
        except Exception as e:
            # Best effort: report and continue with the other engines.
            print(f"[WARN] VietOCR cell OCR failed: {e}")
        else:
            if vietocr_text:
                best_text = vietocr_text
                best_conf = 0.7  # Default confidence for VietOCR

    # Try PaddleOCR / EasyOCR
    if backend == "paddle":
        reader = get_paddle_reader('vi')
    elif backend == "surya":
        text = ocr_with_surya(img_cell, langs=["vi", "en"])
        if text.strip():
            return text.strip()
        reader = get_paddle_reader('vi')
    else:
        reader = get_easyocr_reader()

    if reader is None:
        reader = get_easyocr_reader()
    if reader is None:
        return best_text

    # Variant 1: colour with CLAHE (also reused by the English pass below)
    img_proc1 = cv2.bilateralFilter(img_cell, 5, 50, 50)
    lab = cv2.cvtColor(img_proc1, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4, 4))
    lab = cv2.merge([clahe.apply(l), a, b])
    img_proc1 = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

    def _variants():
        yield img_proc1
        # Variant 2: handwriting preprocessing (removes rules)
        yield preprocess_for_handwriting(img_cell, min_width=300)
        # Variant 3: binary Otsu
        gray = cv2.cvtColor(img_cell, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        yield cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
        # Variant 4: faint or broken handwriting strokes
        yield enhance_faded_text(img_cell)

    for variant in _variants():
        items, conf = ocr_single_pass(reader, variant)
        text = " ".join(it["text"] for it in items)
        if conf > best_conf and text.strip():
            best_conf = conf
            best_text = text

    # Also try English PaddleOCR for specs like "M6x50", "CT3"
    if backend == "paddle":
        reader_en = get_paddle_reader('en')
        if reader_en:
            items_en, conf_en = ocr_single_pass(reader_en, img_proc1)
            text_en = " ".join(it["text"] for it in items_en)
            # Only prefer English if it looks like specs/numbers
            if (conf_en > best_conf and text_en.strip()
                    and re.search(r'[A-Z0-9]', text_en)):
                best_conf = conf_en
                best_text = text_en

    return best_text


def ocr_table_grid(img, backend="paddle"):
    """OCR a table by detecting its grid first, then reading each cell.

    Returns:
        ``{"rows": [...], "text": "..."}`` with domain-corrected cells, or
        ``None`` when no grid is detected or every cell is empty.
    """
    result = detect_table_cells_by_intersection(img)
    if isinstance(result, tuple):
        cells, grid = result
    else:
        cells = result
        grid = None

    if not grid:
        return None

    print(f" Grid detected: {len(grid)} rows")

    pad = 3  # shrink each box inwards to skip the ruling line itself
    all_rows = []
    for row_cells in grid:
        row_texts = []
        for x1, y1, x2, y2 in row_cells:
            cy1 = max(0, y1 + pad)
            cx1 = max(0, x1 + pad)
            cy2 = min(img.shape[0], y2 - pad)
            cx2 = min(img.shape[1], x2 - pad)
            if cy2 <= cy1 or cx2 <= cx1:
                row_texts.append("")
                continue
            text = ocr_cell_improved(img[cy1:cy2, cx1:cx2], backend=backend)
            row_texts.append(text.strip())
        if any(row_texts):  # Skip empty rows
            all_rows.append(row_texts)

    if not all_rows:
        return None

    # Use the most common column count as the table width.
    col_counts = [len(r) for r in all_rows]
    expected_cols = max(set(col_counts), key=col_counts.count)

    corrected_rows = []
    for row in all_rows:
        # Normalise every row to the same column count.
        row = row[:expected_cols] + [""] * (expected_cols - len(row))
        if expected_cols >= 4:
            corrected = correct_table_row(row, num_columns=expected_cols)
        else:
            corrected = [correct_technical_text(cell) for cell in row]
        corrected_rows.append(corrected)

    return {"rows": corrected_rows, "text": _rows_to_text(corrected_rows)}


def _ocr_table_ppstructure(img):
    """Try PPStructure table recognition; return a result dict or ``None``."""
    pp_engine = get_pp_structure()
    if pp_engine is None:
        return None
    try:
        w = img.shape[1]
        if w < 1200:
            scale = 1200 / w
            img_scaled = cv2.resize(img, None, fx=scale, fy=scale,
                                    interpolation=cv2.INTER_CUBIC)
        else:
            img_scaled = img

        for item in pp_engine(img_scaled):
            if item.get('type') != 'table':
                continue
            table_html = item.get('res', {}).get('html', '')
            if not table_html:
                continue
            rows = parse_html_table(table_html)
            if rows:
                corrected_rows = _auto_correct_rows(rows)
                print(f" PPStructure: {len(corrected_rows)} rows")
                return {"rows": corrected_rows,
                        "text": _rows_to_text(corrected_rows),
                        "html": table_html}
    except Exception as e:
        print(f" PPStructure error: {e}")
    return None


def ocr_table(img_path, backend="paddle"):
    """OCR a table image, trying three strategies in order.

    1. Grid-based cell detection, 2. PPStructure (paddle backend only),
    3. contour-based cells, which itself falls back to full-image OCR.

    Returns:
        A dict with "rows" and "text" (plus "html" from PPStructure); empty
        rows and text when the image cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        return {"rows": [], "text": ""}

    # Strategy 1: Grid-based cell detection + OCR
    print(f" Trying grid-based table OCR...")
    result = ocr_table_grid(img, backend)
    if result and result.get("rows"):
        print(f" Grid OCR: {len(result['rows'])} rows")
        return result

    # Strategy 2: PPStructure (if paddle backend)
    if backend == "paddle":
        result = _ocr_table_ppstructure(img)
        if result is not None:
            return result

    # Strategy 3: Contour-based cell detection
    print(f" Trying contour-based table OCR...")
    result = ocr_table_manual(img, img_path, backend)

    # Apply domain correction to final result
    if result.get("rows"):
        corrected_rows = _auto_correct_rows(result["rows"])
        result["rows"] = corrected_rows
        result["text"] = _rows_to_text(corrected_rows)

    return result


def ocr_table_manual(img, img_path, backend="paddle"):
    """OCR a table from contour-detected cells; fall back to full-image OCR.

    ``img_path`` is accepted for interface compatibility and is not used.
    """
    cells = detect_table_structure(img)
    if cells:
        img_h, img_w = img.shape[:2]
        pad = 3
        ocr_results = []
        for (x1, y1, x2, y2) in cells:
            cell_w, cell_h = x2 - x1, y2 - y1
            # Skip the outer table frame and tiny fragments.
            if cell_w > img_w * 0.9 and cell_h > img_h * 0.9:
                continue
            if cell_w < 15 or cell_h < 15:
                continue
            cy1 = max(0, y1 - pad)
            cx1 = max(0, x1 - pad)
            cy2 = min(img_h, y2 + pad)
            cx2 = min(img_w, x2 + pad)
            text = ocr_cell_improved(img[cy1:cy2, cx1:cx2], backend=backend)
            if text:
                ocr_results.append({
                    "text": text.strip(),
                    "x": (x1 + x2) // 2, "y": (y1 + y2) // 2,
                    "box": (x1, y1, x2, y2)
                })
        if ocr_results:
            rows = group_rows(ocr_results, vertical_thresh_ratio=0.5)
            return {"rows": rows, "text": _rows_to_text(rows)}
    return ocr_table_fullimage(img, backend)


_pp_structure = None


def get_pp_structure():
    """Return the shared PPStructure engine, or ``None`` if it cannot start."""
    global _pp_structure
    if _pp_structure is not None:
        return _pp_structure
    try:
        from paddleocr import PPStructure
        print("[INFO] Initializing PPStructure...")
        _pp_structure = PPStructure(
            table=True, ocr=True, lang='vi', show_log=False,
            use_gpu=(DEVICE == "cuda"),
        )
        return _pp_structure
    except Exception as e:
        print(f"[WARN] PPStructure init failed: {e}")
        return None


def parse_html_table(html_str):
    """Extract cell texts from an HTML table string.

    Each ``<tr>`` becomes one row and each ``<td>``/``<th>`` one cell; nested
    tags are stripped and HTML entities decoded. Rows without cells are
    skipped.

    Returns:
        A list of rows, each a list of cell strings.
    """
    from html import unescape

    rows = []
    tr_pattern = re.findall(r'<tr\b[^>]*>(.*?)</tr>', html_str,
                            re.DOTALL | re.IGNORECASE)
    for tr in tr_pattern:
        cells = re.findall(r'<t[dh]\b[^>]*>(.*?)</t[dh]>', tr,
                           re.DOTALL | re.IGNORECASE)
        clean_cells = [unescape(re.sub(r'<[^>]+>', '', cell)).strip()
                       for cell in cells]
        if clean_cells:
            rows.append(clean_cells)
    return rows


def ocr_table_fullimage(img, backend="paddle"):
    """OCR the whole table image and group the detections into rows.

    With the "surya" backend each recognised line becomes a one-cell row.
    """
    if backend == "surya":
        text = ocr_with_surya(img, langs=["vi", "en"])
        lines = [line.strip() for line in text.split("\n") if line.strip()]
        return {"rows": [[line] for line in lines], "text": text}

    reader = get_paddle_reader('vi') if backend == "paddle" else get_easyocr_reader()
    if reader is None:
        reader = get_easyocr_reader()

    img_proc = preprocess_for_ocr(img, min_width=1500, mode="table")
    items, _ = ocr_single_pass(reader, img_proc)

    if not items:
        # Try handwriting preprocessing
        img_hw = preprocess_for_handwriting(img, min_width=1800)
        items, _ = ocr_single_pass(reader, img_hw)

    if not items:
        return {"rows": [], "text": ""}

    # Apply corrections
    for item in items:
        item["text"] = correct_technical_text(item["text"])

    rows = group_rows(items, vertical_thresh_ratio=0.6)
    return {"rows": rows, "text": _rows_to_text(rows)}


# ============================================================
# GROUP ROWS
# ============================================================
def group_rows(items, vertical_thresh_ratio=0.6):
    """Group OCR items into text rows by their vertical position.

    Items are sorted by "y"; an item joins the current row when it lies less
    than the threshold below the previous item. The threshold is
    ``max(8, median_gap * vertical_thresh_ratio)`` (12 for a single item).
    Each row is ordered left to right by "x".

    Args:
        items: dicts with at least "text", "x" and "y".
        vertical_thresh_ratio: fraction of the median y-gap used as cut-off.

    Returns:
        A list of rows, each a list of the items' texts.
    """
    if not items:
        return []

    items_sorted = sorted(items, key=lambda it: it["y"])
    y_vals = [it["y"] for it in items_sorted]

    if len(y_vals) > 1:
        gaps = [below - above for above, below in zip(y_vals, y_vals[1:])]
        thresh = max(8, np.median(gaps) * vertical_thresh_ratio)
    else:
        thresh = 12

    rows = [[items_sorted[0]]]
    for it in items_sorted[1:]:
        if it["y"] - rows[-1][-1]["y"] < thresh:
            rows[-1].append(it)
        else:
            rows.append([it])

    return [[it["text"] for it in sorted(row, key=lambda it: it["x"])]
            for row in rows]
x["y"]) y_vals = [it["y"] for it in items_sorted] if len(y_vals) > 1: gaps = [y_vals[i+1] - y_vals[i] for i in range(len(y_vals)-1)] median_gap = np.median(gaps) thresh = max(8, median_gap * vertical_thresh_ratio) else: thresh = 12 rows = [] current_row = [items_sorted[0]] for it in items_sorted[1:]: if it["y"] - current_row[-1]["y"] < thresh: current_row.append(it) else: current_row.sort(key=lambda x: x["x"]) rows.append(current_row) current_row = [it] current_row.sort(key=lambda x: x["x"]) rows.append(current_row) return [[it["text"] for it in row] for row in rows] # ============================================================ # POST-PROCESSING # ============================================================ def post_process_ocr_text(text): if not text: return text text = re.sub(r'(?<=[0-9])O(?=[0-9])', '0', text) text = re.sub(r'(?<=M)O', '0', text) text = re.sub(r'(?<=Ø)O', '0', text) text = re.sub(r'(?<=[0-9])[lI](?=[0-9])', '1', text) text = re.sub(r'(\d+)\s*[xX]\s*(\d+)', r'\1×\2', text) text = re.sub(r'\s+', ' ', text).strip() # Domain correction text = correct_technical_text(text) return text # ============================================================ # OCR NOTE # ============================================================ def ocr_note(img_path, backend="paddle"): img = cv2.imread(img_path) if img is None: return "" if backend == "surya": text = ocr_with_surya(img, langs=["vi", "en"]) lines = [line.strip() for line in text.split("\n") if line.strip()] processed = [post_process_ocr_text(t) for t in lines] return "\n".join(processed) reader_vi = get_paddle_reader('vi') if backend == "paddle" else None reader_en = get_paddle_reader('en') if backend == "paddle" else None if reader_vi is None and reader_en is None: reader_vi = get_easyocr_reader() best_items = [] best_conf = 0.0 if reader_vi: items, conf = multi_pass_ocr(img, reader_vi, "note") if conf > best_conf: best_conf = conf best_items = items if reader_en: items, conf = multi_pass_ocr(img, reader_en, 
"note") if conf > best_conf: best_conf = conf best_items = items texts = [it["text"] for it in best_items] processed = [post_process_ocr_text(t) for t in texts] processed = [t for t in processed if t] return "\n".join(processed) # ============================================================ # MAIN PIPELINE # ============================================================ def run_pipeline(image_path, output_dir="outputs", checkpoint="best.pt", conf_thresh=0.3, ocr_backend="paddle"): image_path = str(image_path) img_name = Path(image_path).name stem = Path(image_path).stem crop_dir = Path(output_dir) / stem / "crops" crop_dir.mkdir(parents=True, exist_ok=True) model = get_det_model(checkpoint) results = model(image_path, imgsz=1024, conf=conf_thresh, iou=0.5, device=DEVICE, verbose=False) img_bgr = cv2.imread(image_path) if img_bgr is None: raise ValueError(f"Cannot read: {image_path}") objects = [] for i, box in enumerate(results[0].boxes): x1, y1, x2, y2 = map(int, box.xyxy[0].tolist()) cls_idx = int(box.cls[0]) conf_val = round(float(box.conf[0]), 4) cls_raw = CLASS_NAMES[cls_idx] cls_show = CLASS_DISPLAY[cls_raw] pad = 10 crop = img_bgr[max(0, y1-pad):min(img_bgr.shape[0], y2+pad), max(0, x1-pad):min(img_bgr.shape[1], x2+pad)] crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg") cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 98]) ocr_content = None if cls_raw == "note": print(f"[OCR] Note #{i+1} ({x2-x1}x{y2-y1}px)...") ocr_content = ocr_note(crop_path, backend=ocr_backend) print(f" → {repr(ocr_content[:120]) if ocr_content else 'EMPTY'}") elif cls_raw == "table": print(f"[OCR] Table #{i+1} ({x2-x1}x{y2-y1}px)...") ocr_content = ocr_table(crop_path, backend=ocr_backend) preview = ocr_content.get("text", "")[:120] print(f" → {repr(preview) if preview else 'EMPTY'}") objects.append({ "id": i+1, "class": cls_show, "confidence": conf_val, "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}, "crop_path": crop_path, "ocr_content": ocr_content, }) color = 
COLORS[cls_raw] cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 2) label = f"{cls_show} {conf_val:.2f}" (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2) cv2.rectangle(img_bgr, (x1, y1-th-10), (x1+tw+8, y1), color, -1) cv2.putText(img_bgr, label, (x1+4, y1-4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) vis_path = str(Path(output_dir) / stem / "result_vis.jpg") cv2.imwrite(vis_path, img_bgr) result = {"image": img_name, "objects": objects} json_path = str(Path(output_dir) / stem / "result.json") with open(json_path, "w", encoding="utf-8") as f: json.dump(result, f, ensure_ascii=False, indent=2) print(f"[✓] {len(objects)} objects | {vis_path} | {json_path}") return result, vis_path if __name__ == "__main__": import sys img = sys.argv[1] if len(sys.argv) > 1 else "test.jpg" backend = sys.argv[2] if len(sys.argv) > 2 else "easyocr" result, _ = run_pipeline(img, ocr_backend=backend) print(json.dumps(result, ensure_ascii=False, indent=2))