File size: 5,111 Bytes
f8800e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145

"""
Bengali OCR — Full Pipeline
Detection (YOLOv8) + Recognition (BengaliCRNN)

Usage:
    from pipeline import BengaliDocOCR
    ocr = BengaliDocOCR.from_hub()
    result = ocr.read_document("page.jpg")
    print(result["text"])
"""
import json, os, torch
from pathlib import Path
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

# Hugging Face Hub repositories that host the two halves of the pipeline.
DETECT_REPO, RECOG_REPO = (
    "Sarjinkhan2003/bengali-ocr-detection",    # YOLO detector weights (bengali_det.pt)
    "Sarjinkhan2003/bengali-ocr-recognition",  # CRNN source, weights and vocab.json
)


class BengaliDocOCR:
    """
    Full Bengali document OCR pipeline.
    Combines:
      - YOLOv8n text detection
      - LightCRNN text recognition

    The detector proposes text boxes, the boxes are sorted into reading
    order, and each crop is decoded by the recognizer using greedy
    CTC-style decoding (class index 0 is treated as the blank).
    """

    def __init__(self, det_model, rec_model, idx2char,
                 img_h=64, img_w=256, device="cpu"):
        """
        Build a pipeline from already-loaded models.

        Args:
            det_model: detector exposing an ultralytics-style
                ``predict(source, conf=..., verbose=..., device=...)``.
            rec_model: torch recognition module; it is moved to ``device``
                and put in eval mode.
            idx2char: mapping from integer class index to output character.
            img_h, img_w: size every crop is resized to before recognition.
            device: torch device used for recognition and detection.
        """
        self.det     = det_model
        self.rec     = rec_model.to(device).eval()
        self.idx2char= idx2char
        self.device  = device
        self.img_h   = img_h
        self.img_w   = img_w
        # Crops become 1-channel tensors of size (img_h, img_w) scaled to [-1, 1].
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_h, img_w)),
            transforms.ToTensor(),
            transforms.Normalize([0.5],[0.5])
        ])

    @classmethod
    def from_hub(cls, device="cpu"):
        """
        Download both models from the Hugging Face Hub and build a pipeline.

        Args:
            device: torch device for the recognizer (also used for detection
                at inference time).

        Returns:
            A ready-to-use ``BengaliDocOCR`` instance.
        """
        from ultralytics import YOLO
        import importlib.util

        # Detection model
        det_path = hf_hub_download(DETECT_REPO, "bengali_det.pt")
        det_model = YOLO(det_path)

        # Recognition model: architecture source, weights and vocabulary.
        net_path   = hf_hub_download(RECOG_REPO, "bengali_crnn.py")
        ckpt_path  = hf_hub_download(RECOG_REPO, "bengali_crnn.pth")
        vocab_path = hf_hub_download(RECOG_REPO, "vocab.json")

        # SECURITY NOTE(review): this executes Python source downloaded from
        # the Hub. Only use repos you trust; consider pinning a `revision=`.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod  = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        # Use a context manager so the vocab file handle is closed promptly.
        with open(vocab_path, encoding="utf-8") as fh:
            vocab = json.load(fh)
        # JSON object keys are always strings; restore integer class indices.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}
        # NOTE(review): meaning of (1, 256, 256) depends on bengali_crnn.Model's
        # signature — presumably 1 input channel (matches Grayscale(1)); confirm.
        rec_model = mod.Model(1, 256, 256, vocab["num_classes"])
        # NOTE: torch.load unpickles the checkpoint; before torch 2.6 it is not
        # weights-only by default, so only load checkpoints from trusted repos.
        ckpt = torch.load(ckpt_path, map_location=device)
        rec_model.load_state_dict(ckpt["model_state_dict"])

        return cls(det_model, rec_model, idx2char, device=device)

    def _recognize(self, crop):
        """
        Run recognition on a single cropped word image.

        Args:
            crop: image accepted by ``self.tf`` (a PIL image in this pipeline).

        Returns:
            The decoded string. Repeated indices are collapsed, index 0 is
            dropped, and indices missing from ``idx2char`` map to "".
        """
        tensor = self.tf(crop).unsqueeze(0).to(self.device)  # add batch dim
        with torch.no_grad():
            out = self.rec(tensor)
        # The permute implies the recognizer emits (time, batch, classes);
        # after it, argmax over classes gives (batch, time) predictions.
        _, preds = out.permute(1,0,2).max(2)
        chars, prev = [], None
        for p in preds[0].tolist():
            # Greedy CTC decoding: skip blanks (0) and collapse repeats.
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def _sort_boxes(self, boxes):
        """
        Sort detected boxes in reading order:
        top-to-bottom, left-to-right within each row.

        Each box is a sequence whose first four entries are x1, y1, x2, y2.
        Any trailing entries (for example a confidence score) travel with
        the box unchanged. Boxes are ordered by y-center. Consecutive boxes
        whose y-centers differ by less than the row threshold share a row.
        The threshold is 60% of the *median* box height, with a minimum of
        10, so one unusually tall or tiny box cannot skew grouping for the
        whole page.

        Returns:
            A new list of the same box objects in reading order. Empty or
            falsy input is returned unchanged.
        """
        import statistics

        if not boxes:
            return boxes

        def y_center(box):
            return (box[1] + box[3]) / 2

        by_y = sorted(boxes, key=y_center)
        median_h = statistics.median([box[3] - box[1] for box in by_y])
        line_thresh = max(10, median_h * 0.6)

        rows, current_row = [], [by_y[0]]
        for box in by_y[1:]:
            if abs(y_center(box) - y_center(current_row[-1])) < line_thresh:
                current_row.append(box)
            else:
                rows.append(current_row)
                current_row = [box]
        rows.append(current_row)
        # Within each row, read left to right by x1.
        return [box for row in rows for box in sorted(row, key=lambda r: r[0])]

    def read_document(self, image_path, conf=0.25):
        """
        Full pipeline: detect → sort → recognize → assemble.

        Args:
            image_path: path to the page image.
            conf: minimum detection confidence passed to the detector.

        Returns dict:
          text     : recognized strings joined by single spaces, in reading order
          items    : list of {"bbox": [x1,y1,x2,y2], "text": str, "conf": float}
          pageCount: 1
        """
        # Decode once and close the file immediately. convert() returns an
        # independent copy, so the source image is no longer needed.
        with Image.open(image_path) as src:
            img = src.convert("RGB")

        # Use the pipeline's device for detection too, not only for recognition.
        results = self.det.predict(image_path, conf=conf, verbose=False,
                                   device=self.device)
        # Each detection is [x1, y1, x2, y2, score]. _sort_boxes reads only
        # the first four entries, so the score stays attached to its box.
        detections = [box.xyxy[0].tolist() + [box.conf[0].item()]
                      for box in results[0].boxes]

        items, texts = [], []
        for *coords, score in self._sort_boxes(detections):
            x1, y1, x2, y2 = (int(v) for v in coords)
            crop = img.crop((x1, y1, x2, y2))
            # Skip degenerate boxes that are too small to recognize.
            if crop.width < 4 or crop.height < 4:
                continue
            text = self._recognize(crop)
            if text.strip():
                items.append({"bbox": [x1, y1, x2, y2], "text": text,
                              "conf": score})
                texts.append(text)

        return {
            "text"      : " ".join(texts),
            "items"     : items,
            "pageCount" : 1
        }