File size: 5,111 Bytes
f8800e2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""
Bengali OCR — Full Pipeline
Detection (YOLOv8) + Recognition (BengaliCRNN)

Usage:
    from pipeline import BengaliDocOCR

    ocr = BengaliDocOCR.from_hub()
    result = ocr.read_document("page.jpg")
    print(result["text"])
"""
import json
import os
import statistics
from itertools import groupby
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
DETECT_REPO = "Sarjinkhan2003/bengali-ocr-detection"
RECOG_REPO = "Sarjinkhan2003/bengali-ocr-recognition"
class BengaliDocOCR:
    """
    Full Bengali document OCR pipeline.

    Combines:
      - a text detector (a YOLO model exposing ``predict``), and
      - a text recogniser whose per-step class scores are greedily decoded
        into characters through ``idx2char``.

    Build an instance with :meth:`from_hub`, or pass already-loaded models
    to the constructor.
    """

    def __init__(self, det_model, rec_model, idx2char,
                 img_h=64, img_w=256, device="cpu"):
        """
        Store the models and build the crop pre-processing transform.

        Args:
            det_model: detection model; ``read_document`` calls its
                ``predict`` method.
            rec_model: recognition ``torch`` module; it is moved to
                ``device`` and switched to eval mode here.
            idx2char: mapping from integer class index to character string.
            img_h: height every crop is resized to before recognition.
            img_w: width every crop is resized to before recognition.
            device: torch device used for recognition.
        """
        self.det = det_model
        self.rec = rec_model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        self.img_h = img_h
        self.img_w = img_w
        # Crop -> single-channel tensor of fixed size, scaled to [-1, 1].
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_h, img_w)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @classmethod
    def from_hub(cls, device="cpu"):
        """
        Download both models from HuggingFace and build the pipeline.

        Args:
            device: torch device the recognition model is loaded onto.

        Returns:
            A ready-to-use ``BengaliDocOCR`` instance (default crop size).
        """
        from ultralytics import YOLO
        import importlib.util

        # Detection model
        det_path = hf_hub_download(DETECT_REPO, "bengali_det.pt")
        det_model = YOLO(det_path)

        # Recognition model: network definition, weights and vocabulary
        net_path = hf_hub_download(RECOG_REPO, "bengali_crnn.py")
        ckpt_path = hf_hub_download(RECOG_REPO, "bengali_crnn.pth")
        vocab_path = hf_hub_download(RECOG_REPO, "vocab.json")

        # SECURITY NOTE: this imports (i.e. executes) a Python file downloaded
        # from the hub. Only point RECOG_REPO at a repository you trust.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        # Context manager so the vocabulary file handle is closed promptly.
        with open(vocab_path, encoding="utf-8") as fh:
            vocab = json.load(fh)
        # JSON object keys are always strings; convert back to int indices.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        rec_model = mod.Model(1, 256, 256, vocab["num_classes"])
        # SECURITY NOTE: torch.load unpickles the checkpoint; the same trust
        # requirement as above applies to the downloaded .pth file.
        ckpt = torch.load(ckpt_path, map_location=device)
        rec_model.load_state_dict(ckpt["model_state_dict"])
        return cls(det_model, rec_model, idx2char, device=device)

    def _recognize(self, crop):
        """
        Run recognition on a single cropped word image.

        Args:
            crop: PIL image of one detected text region.

        Returns:
            The decoded string (may be empty).
        """
        tensor = self.tf(crop).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.rec(tensor)
        # NOTE(review): assumes the model returns (time, batch, classes);
        # the permute makes it batch-first before the arg-max over classes.
        _, preds = out.permute(1, 0, 2).max(2)
        labels = preds[0].tolist()
        # Greedy decode: collapse runs of the same index, then drop index 0
        # (presumably the CTC blank -- confirm against the training vocab).
        # Indices missing from idx2char contribute nothing.
        return "".join(
            self.idx2char.get(label, "")
            for label, _run in groupby(labels)
            if label != 0
        )

    def _sort_boxes(self, boxes):
        """
        Sort detected boxes in reading order:
        top-to-bottom, left-to-right within each row.

        Rows are grouped by vertical proximity of consecutive box centres.

        Args:
            boxes: list of sequences starting with ``x1, y1, x2, y2``; any
                further elements (e.g. a confidence) are carried along
                untouched.

        Returns:
            A new list with the same box objects in reading order; an empty
            (or ``None``) input is returned as-is.
        """
        if not boxes:
            return boxes

        def y_center(box):
            return (box[1] + box[3]) / 2

        by_y = sorted(boxes, key=y_center)

        # Two consecutive boxes share a row when their centres differ by less
        # than 60% of the typical box height (never less than 10 px). The
        # median height is used so that one tiny or oversized detection at
        # the top of the page cannot distort the threshold for every row.
        typical_height = statistics.median(b[3] - b[1] for b in by_y)
        line_thresh = max(10, typical_height * 0.6)

        rows = [[by_y[0]]]
        for box in by_y[1:]:
            previous = rows[-1][-1]
            if abs(y_center(box) - y_center(previous)) < line_thresh:
                rows[-1].append(box)
            else:
                rows.append([box])

        # Left-to-right inside each row, rows already top-to-bottom.
        return [box
                for row in rows
                for box in sorted(row, key=lambda b: b[0])]

    def read_document(self, image_path, conf=0.25):
        """
        Full pipeline: detect -> sort -> recognize -> assemble.

        Args:
            image_path: path to the page image (an already-loaded PIL image
                is accepted as well).
            conf: minimum detection confidence passed to the detector.

        Returns dict:
            text     : full document text, recognised pieces joined by spaces
            items    : list of {"bbox": [x1,y1,x2,y2], "text": str, "conf": float}
            pageCount: 1
        """
        from PIL import ImageOps

        if isinstance(image_path, Image.Image):
            img = ImageOps.exif_transpose(image_path).convert("RGB")
        else:
            # Context manager closes the underlying file once decoded.
            with Image.open(image_path) as src:
                img = ImageOps.exif_transpose(src).convert("RGB")

        # Detect on the very image we crop from. Letting the detector load
        # the file itself can yield a differently decoded image (for example
        # with EXIF orientation applied), and then the boxes would not line
        # up with the pixels being cropped.
        results = self.det.predict(img, conf=conf, verbose=False)
        detections = [box.xyxy[0].tolist() + [box.conf[0].item()]
                      for box in results[0].boxes]

        items, texts = [], []
        # The confidence rides along as the fifth element through sorting.
        for det in self._sort_boxes(detections):
            x1, y1, x2, y2 = (int(v) for v in det[:4])
            # Skip degenerate regions too small to hold legible text.
            if x2 - x1 < 4 or y2 - y1 < 4:
                continue
            text = self._recognize(img.crop((x1, y1, x2, y2)))
            if text.strip():
                items.append({
                    "bbox": [x1, y1, x2, y2],
                    "text": text,
                    "conf": det[4],
                })
                texts.append(text)

        return {
            "text": " ".join(texts),
            "items": items,
            "pageCount": 1,
        }
|