# pipeline.py - Sarjinkhan2003/bengali-ocr-detection (YOLOv8 detector, mAP50 = 0.8790)
"""
Bengali OCR — Full Pipeline
Detection (YOLOv8) + Recognition (BengaliCRNN)
Usage:
from pipeline import BengaliDocOCR
ocr = BengaliDocOCR.from_hub()
result = ocr.read_document("page.jpg")
print(result["text"])
"""
import json

import torch
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
DETECT_REPO = "Sarjinkhan2003/bengali-ocr-detection"
RECOG_REPO = "Sarjinkhan2003/bengali-ocr-recognition"
class BengaliDocOCR:
"""
Full Bengali document OCR pipeline.
Combines:
- YOLOv8n text detection
- LightCRNN text recognition
"""
def __init__(self, det_model, rec_model, idx2char,
img_h=64, img_w=256, device="cpu"):
self.det = det_model
self.rec = rec_model.to(device).eval()
self.idx2char= idx2char
self.device = device
self.img_h = img_h
self.img_w = img_w
self.tf = transforms.Compose([
transforms.Grayscale(1),
transforms.Resize((img_h, img_w)),
transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])
])

    @classmethod
    def from_hub(cls, device="cpu"):
        """Download both models from the Hugging Face Hub and build the pipeline."""
        from ultralytics import YOLO
        import importlib.util

        # Detection model
        det_path = hf_hub_download(DETECT_REPO, "bengali_det.pt")
        det_model = YOLO(det_path)

        # Recognition model: architecture file, weights, and vocabulary
        net_path = hf_hub_download(RECOG_REPO, "bengali_crnn.py")
        ckpt_path = hf_hub_download(RECOG_REPO, "bengali_crnn.pth")
        vocab_path = hf_hub_download(RECOG_REPO, "vocab.json")

        # Import the CRNN architecture module shipped alongside the weights
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        with open(vocab_path, encoding="utf-8") as f:
            vocab = json.load(f)
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        rec_model = mod.Model(1, 256, 256, vocab["num_classes"])
        ckpt = torch.load(ckpt_path, map_location=device)
        rec_model.load_state_dict(ckpt["model_state_dict"])
        return cls(det_model, rec_model, idx2char, device=device)

    def _recognize(self, crop):
        """Run recognition on a single cropped word image."""
        tensor = self.tf(crop).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.rec(tensor)
        # Greedy CTC decode: permute the output to batch-first, take the argmax
        # per time step, then collapse consecutive repeats and drop blanks
        # (index 0). prev is updated on every frame so that a blank between two
        # identical characters keeps them from being merged.
        _, preds = out.permute(1, 0, 2).max(2)
        chars, prev = [], None
        for p in preds[0].tolist():
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def _sort_boxes(self, boxes):
        """
        Sort detected boxes in reading order:
        top-to-bottom, left-to-right within each row.
        Rows are grouped by vertical proximity.
        """
        if not boxes:
            return boxes
        # Sort by y-center first
        boxes_sorted = sorted(boxes, key=lambda b: (b[1] + b[3]) / 2)
        # Group into rows: boxes whose y-centers lie within line_thresh of the
        # previous box's y-center share a row. The threshold is 60% of the
        # topmost box's height, with a 10 px floor.
        line_thresh = max(10, (boxes_sorted[0][3] - boxes_sorted[0][1]) * 0.6)
        rows, current_row = [], [boxes_sorted[0]]
        for b in boxes_sorted[1:]:
            cy_prev = (current_row[-1][1] + current_row[-1][3]) / 2
            cy_curr = (b[1] + b[3]) / 2
            if abs(cy_curr - cy_prev) < line_thresh:
                current_row.append(b)
            else:
                rows.append(sorted(current_row, key=lambda b: b[0]))  # sort by x
                current_row = [b]
        rows.append(sorted(current_row, key=lambda b: b[0]))
        return [b for row in rows for b in row]
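
    # Reading-order illustration with hypothetical boxes: for
    # A = [10, 10, 60, 40], B = [70, 12, 120, 42], C = [10, 60, 60, 90],
    # the row threshold is max(10, (40 - 10) * 0.6) = 18 px. A and B have
    # y-centers 25 and 27, so they share a row; C's y-center of 75 starts a
    # new row. The returned order is A, B, C.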

    def read_document(self, image_path, conf=0.25):
        """
        Full pipeline: detect → sort → recognize → assemble.

        Returns dict:
            text      : full document text string
            items     : list of {"bbox": [x1, y1, x2, y2], "text": str, "conf": float}
            pageCount : 1
        """
        img = Image.open(image_path).convert("RGB")
        results = self.det.predict(image_path, conf=conf, verbose=False)
        # Each detection becomes [x1, y1, x2, y2, confidence]
        boxes = [box.xyxy[0].tolist() + [box.conf[0].item()]
                 for box in results[0].boxes]
        # Sort into reading order; _sort_boxes only reads indices 0-3, so the
        # trailing confidence simply rides along.
        sorted_boxes = self._sort_boxes(boxes)

        items, texts = [], []
        for bbox in sorted_boxes:
            x1, y1, x2, y2 = [int(v) for v in bbox[:4]]
            score = float(bbox[4])
            crop = img.crop((x1, y1, x2, y2))
            if crop.width < 4 or crop.height < 4:
                continue
            text = self._recognize(crop)
            if text.strip():
                items.append({"bbox": [x1, y1, x2, y2], "text": text, "conf": score})
                texts.append(text)
        return {
            "text": " ".join(texts),
            "items": items,
            "pageCount": 1,
        }
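

# ---------------------------------------------------------------------------
# Minimal usage sketch mirroring the module docstring. It is a sketch, not
# part of the published pipeline: it assumes the two Hub repos above are
# reachable and that a local test image named "page.jpg" exists; adjust the
# image path, confidence threshold, and device for your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ocr = BengaliDocOCR.from_hub(device=device)
    result = ocr.read_document("page.jpg", conf=0.25)
    print(result["text"])
    for item in result["items"]:
        print(item["bbox"], round(item["conf"], 3), item["text"])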