|
|
| """ |
| Bengali OCR — Full Pipeline |
| Detection (YOLOv8) + Recognition (BengaliCRNN) |
| |
| Usage: |
| from pipeline import BengaliDocOCR |
| ocr = BengaliDocOCR.from_hub() |
| result = ocr.read_document("page.jpg") |
| print(result["text"]) |
| """ |
| import json, os, torch |
| from pathlib import Path |
| from PIL import Image |
| from torchvision import transforms |
| from huggingface_hub import hf_hub_download |
|
|
| DETECT_REPO = "Sarjinkhan2003/bengali-ocr-detection" |
| RECOG_REPO = "Sarjinkhan2003/bengali-ocr-recognition" |
|
|
|
|
class BengaliDocOCR:
    """
    Full Bengali document OCR pipeline.

    Combines:
      - YOLOv8n text detection (ultralytics)
      - LightCRNN text recognition (greedy CTC decoding)
    """

    def __init__(self, det_model, rec_model, idx2char,
                 img_h=64, img_w=256, device="cpu"):
        """
        Args:
            det_model: ultralytics YOLO detector exposing ``.predict()``.
            rec_model: torch ``nn.Module`` recognizer producing CTC logits.
            idx2char: dict[int, str] mapping class index -> character;
                      index 0 is treated as the CTC blank.
            img_h: height word crops are resized to before recognition.
            img_w: width word crops are resized to before recognition.
            device: torch device string for the recognizer ("cpu", "cuda", ...).
        """
        self.det = det_model
        self.rec = rec_model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        self.img_h = img_h
        self.img_w = img_w
        # Grayscale -> fixed (img_h, img_w) -> tensor normalized to [-1, 1].
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_h, img_w)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @classmethod
    def from_hub(cls, device="cpu"):
        """Download both models from the HuggingFace Hub and build the pipeline."""
        from ultralytics import YOLO
        import importlib.util

        # --- detection weights ---
        det_path = hf_hub_download(DETECT_REPO, "bengali_det.pt")
        det_model = YOLO(det_path)

        # --- recognition: model source, checkpoint, vocabulary ---
        net_path = hf_hub_download(RECOG_REPO, "bengali_crnn.py")
        ckpt_path = hf_hub_download(RECOG_REPO, "bengali_crnn.pth")
        vocab_path = hf_hub_download(RECOG_REPO, "vocab.json")

        # Import the model definition shipped alongside the weights.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        # Fix: close the vocab file deterministically (was a bare open()).
        with open(vocab_path, encoding="utf-8") as f:
            vocab = json.load(f)
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        rec_model = mod.Model(1, 256, 256, vocab["num_classes"])
        ckpt = torch.load(ckpt_path, map_location=device)
        rec_model.load_state_dict(ckpt["model_state_dict"])

        return cls(det_model, rec_model, idx2char, device=device)

    def _recognize(self, crop):
        """Greedy-CTC decode a single cropped word image into a text string."""
        tensor = self.tf(crop).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # NOTE(review): output assumed to be (T, B, C) CTC layout —
            # the permute below moves batch first; confirm against the
            # recognizer's forward().
            out = self.rec(tensor)
        _, preds = out.permute(1, 0, 2).max(2)
        # Standard greedy CTC: collapse repeats, drop blanks (index 0).
        chars, prev = [], None
        for p in preds[0].tolist():
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def _sort_boxes(self, boxes):
        """
        Sort detected boxes in reading order: top-to-bottom, then
        left-to-right within each row. Rows are grouped by vertical
        proximity of box centers.

        Boxes are sequences starting with [x1, y1, x2, y2]; any extra
        trailing fields (e.g. a confidence score) are carried through
        untouched, since only indices 0-3 are read.
        """
        if not boxes:
            return boxes

        boxes_sorted = sorted(boxes, key=lambda b: (b[1] + b[3]) / 2)

        # Row threshold: 60% of the first box's height, at least 10 px.
        line_thresh = max(10, (boxes_sorted[0][3] - boxes_sorted[0][1]) * 0.6)
        rows, current_row = [], [boxes_sorted[0]]
        for b in boxes_sorted[1:]:
            cy_prev = (current_row[-1][1] + current_row[-1][3]) / 2
            cy_curr = (b[1] + b[3]) / 2
            if abs(cy_curr - cy_prev) < line_thresh:
                current_row.append(b)
            else:
                rows.append(sorted(current_row, key=lambda b: b[0]))
                current_row = [b]
        rows.append(sorted(current_row, key=lambda b: b[0]))
        return [b for row in rows for b in row]

    def read_document(self, image_path, conf=0.25):
        """
        Full pipeline: detect -> sort -> recognize -> assemble.

        Args:
            image_path: path to the page image file.
            conf: detector confidence threshold.

        Returns dict:
            text     : full document text string (space-joined words)
            items    : list of {"bbox": [x1,y1,x2,y2], "text": str, "conf": float}
            pageCount: 1
        """
        img = Image.open(image_path).convert("RGB")
        results = self.det.predict(image_path, conf=conf, verbose=False)
        boxes = [box.xyxy[0].tolist() + [box.conf[0].item()]
                 for box in results[0].boxes]

        # Fix: sort boxes WITH their confidence attached so the score can be
        # reported per item — the previous code stripped it, so the "conf"
        # field promised by this docstring never appeared in the output.
        sorted_boxes = self._sort_boxes(boxes)

        items, texts = [], []
        for entry in sorted_boxes:
            x1, y1, x2, y2 = (int(v) for v in entry[:4])
            score = float(entry[4]) if len(entry) > 4 else None
            crop = img.crop((x1, y1, x2, y2))
            # Skip degenerate slivers the recognizer cannot resize sanely.
            if crop.width < 4 or crop.height < 4:
                continue
            text = self._recognize(crop)
            if text.strip():
                item = {"bbox": [x1, y1, x2, y2], "text": text}
                if score is not None:
                    item["conf"] = score
                items.append(item)
                texts.append(text)

        return {
            "text": " ".join(texts),
            "items": items,
            "pageCount": 1,
        }
|
|