|
|
| """ |
| Bengali OCR — standalone inference helper. |
| Works independently of EasyOCR if needed. |
| |
| Usage: |
| from inference import BengaliOCR |
| ocr = BengaliOCR.from_hub("Sarjinkhan2003/bengali-ocr-recognition") |
| text = ocr.read("word_image.jpg") |
| """ |
import json

import torch
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
|
|
class BengaliOCR:
    """CTC-based CRNN recognizer for single Bengali word/line images.

    Wraps a CRNN checkpoint (optionally fetched from the HuggingFace Hub)
    behind a simple ``read()`` API.
    """

    # Fixed input geometry the preprocessing resizes every image to
    # (presumably what the network was trained with — confirm against repo).
    IMG_H = 64
    IMG_W = 256

    def __init__(self, model, idx2char, device="cpu"):
        """
        Args:
            model: torch.nn.Module producing per-timestep class logits for a
                (1, 1, IMG_H, IMG_W) input. ``read()`` permutes the output as
                (T, B, C) -> (B, T, C), so the model is assumed to emit
                time-major logits — TODO confirm against bengali_crnn.py.
            idx2char: mapping of class index -> character; index 0 is treated
                as the CTC blank by the decoder in ``read()``.
            device: torch device string to run inference on.
        """
        self.model = model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        # Preprocessing: grayscale, fixed size, normalize to [-1, 1].
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((self.IMG_H, self.IMG_W)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @classmethod
    def from_hub(cls, repo_id, device="cpu"):
        """Load model code, weights, and vocabulary from a HuggingFace repo.

        Args:
            repo_id: hub repository id, e.g.
                ``"Sarjinkhan2003/bengali-ocr-recognition"``.
            device: device to map the checkpoint onto and run the model on.

        Returns:
            A ready-to-use ``BengaliOCR`` instance.
        """
        import importlib.util  # local import: only needed on this path

        net_path = hf_hub_download(repo_id, "bengali_crnn.py")
        ckpt_path = hf_hub_download(repo_id, "bengali_crnn.pth")
        vocab_path = hf_hub_download(repo_id, "vocab.json")

        # SECURITY: exec_module runs arbitrary code from the repo, and
        # torch.load below unpickles the checkpoint — only load trusted repos.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        with open(vocab_path, encoding="utf-8") as f:
            vocab = json.load(f)
        num_classes = vocab["num_classes"]
        # JSON object keys are strings; the decoder indexes with ints.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        model = mod.Model(input_channel=1, output_channel=256,
                          hidden_size=256, num_class=num_classes)
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        return cls(model, idx2char, device)

    def read(self, image_path_or_pil, beam_width=5):
        """Read text from a single word/line image.

        Args:
            image_path_or_pil: a PIL image, or anything ``PIL.Image.open``
                accepts (str path, ``pathlib.Path``, file object).
            beam_width: accepted for API compatibility but currently UNUSED —
                decoding below is greedy best-path CTC, not beam search.

        Returns:
            The decoded text as a str (empty string if only blanks predicted).
        """
        # Accept a PIL image directly; treat everything else as path-like.
        # (Generalizes the original str-only check to Path/file objects.)
        if isinstance(image_path_or_pil, Image.Image):
            img = image_path_or_pil.convert("RGB")
        else:
            img = Image.open(image_path_or_pil).convert("RGB")
        tensor = self.tf(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(tensor)

        # Greedy CTC decode: argmax per timestep, collapse consecutive
        # repeats, drop blanks (class 0).
        _, preds = out.permute(1, 0, 2).max(2)
        chars, prev = [], None
        for p in preds[0].tolist():
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def read_batch(self, images):
        """Read a list of images (paths or PIL); one decoded str per input."""
        return [self.read(img) for img in images]
|
|