"""Bengali OCR — standalone inference helper.

Works independently of EasyOCR if needed.

Usage:
    from inference import BengaliOCR

    ocr = BengaliOCR.from_hub("Sarjinkhan2003/bengali-ocr-recognition")
    text = ocr.read("word_image.jpg")
"""
import json
import os

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms


class BengaliOCR:
    """Wrap a recognition model and a vocabulary for word/line-level OCR.

    Images are converted to single-channel, resized to ``IMG_H`` x ``IMG_W``,
    normalized with mean 0.5 / std 0.5, passed through the model, and the
    per-timestep argmax is collapsed with a greedy CTC-style decode.
    """

    # Input size (height, width) every image is resized to before inference.
    IMG_H = 64
    IMG_W = 256
    # Class index that the greedy decoder treats as the CTC blank.
    BLANK_IDX = 0

    def __init__(self, model, idx2char, device="cpu"):
        """Store the model (moved to *device*, in eval mode) and vocabulary.

        Args:
            model: A ``torch.nn.Module``-like object; it is moved to *device*
                and switched to eval mode.
            idx2char: Mapping from integer class index to output string.
            device: Device spec accepted by ``Module.to`` / ``Tensor.to``.
        """
        self.model = model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((self.IMG_H, self.IMG_W)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @classmethod
    def from_hub(cls, repo_id, device="cpu"):
        """Load directly from HuggingFace.

        Downloads ``bengali_crnn.py`` (network definition),
        ``bengali_crnn.pth`` (checkpoint) and ``vocab.json`` from *repo_id*,
        builds the model and returns a ready-to-use instance.

        Args:
            repo_id: HuggingFace Hub repository id.
            device: Device the weights are mapped to and the model runs on.

        Returns:
            A ``BengaliOCR`` instance.

        Raises:
            KeyError: If ``vocab.json`` lacks ``num_classes`` or ``idx2char``.
        """
        import importlib.util

        # Download files
        net_path = hf_hub_download(repo_id, "bengali_crnn.py")
        ckpt_path = hf_hub_download(repo_id, "bengali_crnn.pth")
        vocab_path = hf_hub_download(repo_id, "vocab.json")

        # Load network class.
        # SECURITY: this executes Python source downloaded from the Hub repo;
        # only use with repositories you trust.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        # Load vocab
        with open(vocab_path, encoding="utf-8") as f:
            vocab = json.load(f)
        num_classes = vocab["num_classes"]
        # JSON object keys are always strings; convert back to int indices.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        # Build model
        model = mod.Model(
            input_channel=1,
            output_channel=256,
            hidden_size=256,
            num_class=num_classes,
        )
        # SECURITY: torch.load unpickles the checkpoint; a malicious file can
        # run arbitrary code. Consider weights_only=True if the checkpoint
        # contents and the installed torch version allow it.
        ckpt = torch.load(ckpt_path, map_location=device)
        # Accept both a training checkpoint that wraps the weights under
        # "model_state_dict" and a bare state dict saved directly.
        state_dict = ckpt
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state_dict = ckpt["model_state_dict"]
        model.load_state_dict(state_dict)
        return cls(model, idx2char, device)

    def read(self, image_path_or_pil, beam_width=5):
        """Read text from a single word/line image.

        Args:
            image_path_or_pil: A filesystem path (``str`` or ``os.PathLike``)
                or an object with a PIL-style ``convert`` method.
            beam_width: Currently ignored — decoding is always greedy. Kept
                so existing callers that pass it keep working.

        Returns:
            The decoded string (empty if nothing but blanks is predicted).
        """
        if isinstance(image_path_or_pil, (str, os.PathLike)):
            # Image.open keeps the file handle open until the image is
            # closed; convert() yields an independent in-memory image, so the
            # handle can be released as soon as the conversion is done.
            with Image.open(image_path_or_pil) as opened:
                img = opened.convert("RGB")
        else:
            img = image_path_or_pil.convert("RGB")

        tensor = self.tf(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(tensor)  # expected [T, 1, C] — model-dependent

        # Greedy decode: best class per timestep for the single batch item.
        _, preds = out.permute(1, 0, 2).max(2)
        return self._greedy_decode(preds[0].tolist())

    def _greedy_decode(self, indices):
        """Collapse consecutive repeats, drop blanks, and map to characters.

        Indices missing from ``idx2char`` contribute an empty string.
        """
        chars = []
        prev = None
        for idx in indices:
            if idx != self.BLANK_IDX and idx != prev:
                chars.append(self.idx2char.get(idx, ""))
            # Track every index (including blanks) so that a blank between
            # two identical classes lets the second one be emitted.
            prev = idx
        return "".join(chars)

    def read_batch(self, images):
        """Read a list of images (PIL images or paths), one at a time.

        Returns:
            A list of decoded strings, in the same order as *images*.
        """
        return [self.read(img) for img in images]