"""
Bengali OCR — Full Pipeline
Detection (YOLOv8) + Recognition (BengaliCRNN)

Usage:
    from pipeline import BengaliDocOCR
    ocr = BengaliDocOCR.from_hub()
    result = ocr.read_document("page.jpg")
    print(result["text"])
"""
import json
import os
from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

DETECT_REPO = "Sarjinkhan2003/bengali-ocr-detection"
RECOG_REPO = "Sarjinkhan2003/bengali-ocr-recognition"


class BengaliDocOCR:
    """
    Full Bengali document OCR pipeline.

    Combines:
      - YOLOv8n text detection
      - LightCRNN text recognition
    """

    def __init__(self, det_model, rec_model, idx2char,
                 img_h=64, img_w=256, device="cpu"):
        """
        Store the two models and build the crop preprocessing transform.

        Args:
            det_model: detector object; only its ``predict`` method is used.
            rec_model: torch module; moved to ``device`` and put in eval mode.
            idx2char: mapping from integer class index to output character.
            img_h: height every crop is resized to before recognition.
            img_w: width every crop is resized to before recognition.
            device: torch device string used for the recognition model.
        """
        self.det = det_model
        self.rec = rec_model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        self.img_h = img_h
        self.img_w = img_w
        # Single-channel, fixed-size, scaled to [-1, 1].
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_h, img_w)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @classmethod
    def from_hub(cls, device="cpu"):
        """
        Download both models from HuggingFace and build the pipeline.

        Args:
            device: torch device string for the recognition model.

        Returns:
            A ready-to-use ``BengaliDocOCR`` instance (default crop size).
        """
        from ultralytics import YOLO
        import importlib.util

        # Detection model
        det_path = hf_hub_download(DETECT_REPO, "bengali_det.pt")
        det_model = YOLO(det_path)

        # Recognition model
        net_path = hf_hub_download(RECOG_REPO, "bengali_crnn.py")
        ckpt_path = hf_hub_download(RECOG_REPO, "bengali_crnn.pth")
        vocab_path = hf_hub_download(RECOG_REPO, "vocab.json")

        # SECURITY NOTE: this executes Python source downloaded from the hub.
        # Only safe as long as RECOG_REPO is trusted.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        # Context manager so the vocab file handle is closed deterministically.
        with open(vocab_path, encoding="utf-8") as fh:
            vocab = json.load(fh)
        # JSON object keys are always strings; convert back to int indices.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        # NOTE(review): the meaning of the (1, 256, 256) arguments is defined
        # by the downloaded bengali_crnn.py — confirm against that file.
        rec_model = mod.Model(1, 256, 256, vocab["num_classes"])
        # SECURITY NOTE: torch.load unpickles the checkpoint; consider
        # weights_only=True if the checkpoint only holds tensors.
        ckpt = torch.load(ckpt_path, map_location=device)
        rec_model.load_state_dict(ckpt["model_state_dict"])

        return cls(det_model, rec_model, idx2char, device=device)

    def _recognize(self, crop) -> str:
        """
        Run recognition on a single cropped word image.

        Args:
            crop: PIL image of one detected text region.

        Returns:
            The decoded string (may be empty).
        """
        tensor = self.tf(crop).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.rec(tensor)
        # presumably `out` is (time, batch, classes) — TODO confirm; after the
        # permute the arg-max is taken over the last axis.
        _, preds = out.permute(1, 0, 2).max(2)

        # Greedy collapse: skip index 0 and consecutive repeats.
        chars, prev = [], None
        for p in preds[0].tolist():
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def _sort_boxes(self, boxes):
        """
        Sort detected boxes in reading order: top-to-bottom,
        left-to-right within each row. Rows are grouped by vertical proximity.

        Args:
            boxes: list of sequences whose first four elements are
                ``x1, y1, x2, y2``. Extra trailing elements (e.g. a
                confidence score) are carried along untouched.

        Returns:
            A new list in reading order (the input itself when it is empty).
        """
        if not boxes:
            return boxes

        def y_center(b):
            return (b[1] + b[3]) / 2

        # Sort by y-center first
        boxes_sorted = sorted(boxes, key=y_center)

        # Row threshold: 60% of the top-most box's height, at least 10 px.
        line_thresh = max(10, (boxes_sorted[0][3] - boxes_sorted[0][1]) * 0.6)

        rows, current_row = [], [boxes_sorted[0]]
        for b in boxes_sorted[1:]:
            # Compare against the most recently added box of the current row.
            if abs(y_center(b) - y_center(current_row[-1])) < line_thresh:
                current_row.append(b)
            else:
                rows.append(sorted(current_row, key=lambda r: r[0]))  # by x
                current_row = [b]
        rows.append(sorted(current_row, key=lambda r: r[0]))

        return [b for row in rows for b in row]

    def read_document(self, image_path, conf=0.25) -> dict:
        """
        Full pipeline: detect → sort → recognize → assemble.

        Args:
            image_path: path of the page image to read.
            conf: minimum detection confidence passed to the detector.

        Returns dict:
            text     : full document text string (items joined by spaces)
            items    : list of {"bbox": [x1,y1,x2,y2], "text": str,
                       "conf": float}
            pageCount: 1
        """
        # Convert inside the context manager so the file handle is released;
        # convert() returns an independent in-memory image.
        with Image.open(image_path) as src:
            img = src.convert("RGB")

        # NOTE(review): detection reads the file by path while crops come from
        # the PIL-decoded image; if the two decoders ever disagree on
        # orientation (e.g. EXIF rotation) boxes will not line up — confirm.
        results = self.det.predict(image_path, conf=conf, verbose=False)
        boxes = [box.xyxy[0].tolist() + [box.conf[0].item()]
                 for box in results[0].boxes]

        # Sort into reading order. _sort_boxes only reads the first four
        # elements, so the confidence rides along with its box.
        sorted_boxes = self._sort_boxes(boxes)

        items, texts = [], []
        for *coords, score in sorted_boxes:
            x1, y1, x2, y2 = [int(v) for v in coords]
            crop = img.crop((x1, y1, x2, y2))
            # Skip crops smaller than 4 px on either side.
            if crop.width < 4 or crop.height < 4:
                continue
            text = self._recognize(crop)
            if text.strip():
                items.append({
                    "bbox": [x1, y1, x2, y2],
                    "text": text,
                    "conf": score,
                })
                texts.append(text)

        return {
            "text": " ".join(texts),
            "items": items,
            "pageCount": 1,
        }