File size: 2,726 Bytes
6342ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

"""
Bengali OCR — standalone inference helper.
Works as a standalone recognizer, with no dependency on EasyOCR.

Usage:
    from inference import BengaliOCR
    ocr = BengaliOCR.from_hub("Sarjinkhan2003/bengali-ocr-recognition")
    text = ocr.read("word_image.jpg")
"""
import json, torch
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

class BengaliOCR:
    """CRNN-based recognizer for single Bengali word/line images.

    Images are converted to grayscale, resized to IMG_H x IMG_W, and
    normalized from [0, 1] to [-1, 1] before being fed to the model.
    Decoding is greedy (best-path) CTC with class index 0 treated as
    the blank label.
    """

    # Fixed input geometry expected by the CRNN backbone.
    IMG_H = 64
    IMG_W = 256

    def __init__(self, model, idx2char, device="cpu"):
        """
        Args:
            model: torch.nn.Module producing [T, B, C] per-timestep class
                scores (CTC-style; index 0 assumed to be blank — TODO
                confirm against the training vocabulary).
            idx2char: dict mapping int class index -> character string.
            device: torch device string, e.g. "cpu" or "cuda".
        """
        self.model = model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((self.IMG_H, self.IMG_W)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [0,1] -> [-1,1]
        ])

    @classmethod
    def from_hub(cls, repo_id, device="cpu"):
        """Load network code, weights and vocab from a HuggingFace repo.

        NOTE(security): this downloads and *executes* arbitrary Python
        (bengali_crnn.py) and unpickles a checkpoint from the repo —
        only use repositories you trust.

        Args:
            repo_id: HuggingFace repo id, e.g. "user/bengali-ocr".
            device: torch device string for the loaded model.

        Returns:
            A ready-to-use BengaliOCR instance.

        Raises:
            ImportError: if the downloaded network file cannot be loaded.
        """
        import importlib.util
        import sys

        net_path = hf_hub_download(repo_id, "bengali_crnn.py")
        ckpt_path = hf_hub_download(repo_id, "bengali_crnn.pth")
        vocab_path = hf_hub_download(repo_id, "vocab.json")

        # Import the downloaded network definition as a module.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load network definition from {net_path}")
        mod = importlib.util.module_from_spec(spec)
        # Register before exec so intra-module references (and any
        # pickling machinery) can resolve the module by name.
        sys.modules[spec.name] = mod
        spec.loader.exec_module(mod)

        with open(vocab_path, encoding="utf-8") as f:
            vocab = json.load(f)
        num_classes = vocab["num_classes"]
        # JSON object keys are strings; decoding needs int indices.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        model = mod.Model(input_channel=1, output_channel=256,
                          hidden_size=256, num_class=num_classes)
        # NOTE(security): torch.load unpickles arbitrary objects; consider
        # weights_only=True if the checkpoint is a plain state-dict wrapper.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        return cls(model, idx2char, device)

    def read(self, image_path_or_pil, beam_width=5):
        """Read text from a single word/line image.

        Args:
            image_path_or_pil: image file path (str) or a PIL.Image.
            beam_width: accepted for API compatibility but currently
                UNUSED — decoding is greedy best-path, not beam search.

        Returns:
            The decoded text string (empty if only blanks are predicted).
        """
        if isinstance(image_path_or_pil, str):
            # Context manager closes the file handle promptly;
            # convert() returns an independent in-memory image.
            with Image.open(image_path_or_pil) as im:
                img = im.convert("RGB")
        else:
            img = image_path_or_pil.convert("RGB")
        tensor = self.tf(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(tensor)   # [T, 1, C]
        # Greedy CTC decode: argmax per timestep, collapse repeats,
        # drop blanks (index 0).
        _, preds = out.permute(1, 0, 2).max(2)
        chars, prev = [], None
        for p in preds[0].tolist():
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def read_batch(self, images):
        """Read a list of images (paths or PIL.Images) via read()."""
        return [self.read(img) for img in images]