# bengali-ocr-recognition / inference.py
# Bengali OCR recognition model — CER=0.0062
# (HF Hub page residue: uploaded by Sarjinkhan2003, commit 6342ae9, verified)
"""
Bengali OCR — standalone inference helper.
Works independently of EasyOCR if needed.
Usage:
from inference import BengaliOCR
ocr = BengaliOCR.from_hub("Sarjinkhan2003/bengali-ocr-recognition")
text = ocr.read("word_image.jpg")
"""
import json
import os

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
class BengaliOCR:
    """Standalone inference wrapper for a Bengali CRNN + CTC recognition model.

    Wraps a pretrained CRNN: preprocessing mirrors training (grayscale,
    fixed 64x256 resize, normalization to [-1, 1]), and decoding is greedy
    CTC (argmax per timestep, collapse repeats, drop blank index 0).
    """

    # Fixed input geometry the CRNN was trained with.
    IMG_H = 64
    IMG_W = 256

    def __init__(self, model, idx2char, device="cpu"):
        """Hold an already-constructed model plus its index-to-character map.

        Args:
            model: torch.nn.Module producing [T, B, C] per-timestep logits.
            idx2char: dict mapping int class index -> character (0 = CTC blank).
            device: torch device string the model and inputs are moved to.
        """
        self.model = model.to(device).eval()
        self.idx2char = idx2char
        self.device = device
        # Preprocessing must match training exactly: single-channel,
        # fixed size, then scaled from [0, 1] to [-1, 1].
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((self.IMG_H, self.IMG_W)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @classmethod
    def from_hub(cls, repo_id, device="cpu"):
        """Load network code, weights and vocabulary from the HuggingFace Hub.

        Downloads three artifacts from `repo_id`: the model definition
        (bengali_crnn.py), the checkpoint (bengali_crnn.pth) and the
        vocabulary (vocab.json), then builds a ready-to-use instance.
        """
        import importlib.util

        net_path = hf_hub_download(repo_id, "bengali_crnn.py")
        ckpt_path = hf_hub_download(repo_id, "bengali_crnn.pth")
        vocab_path = hf_hub_download(repo_id, "vocab.json")

        # Import the model class shipped alongside the checkpoint.
        spec = importlib.util.spec_from_file_location("bengali_crnn", net_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        with open(vocab_path, encoding="utf-8") as f:
            vocab = json.load(f)
        num_classes = vocab["num_classes"]
        # JSON object keys are strings; the decoder indexes with ints.
        idx2char = {int(k): v for k, v in vocab["idx2char"].items()}

        model = mod.Model(input_channel=1, output_channel=256,
                          hidden_size=256, num_class=num_classes)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted repos (consider weights_only=True if the
        # checkpoint layout permits it).
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        return cls(model, idx2char, device)

    def read(self, image_path_or_pil, beam_width=5):
        """Read text from a single word/line image.

        Args:
            image_path_or_pil: filesystem path (str or os.PathLike) or a
                PIL.Image-like object exposing `.convert()`.
            beam_width: accepted for interface compatibility but currently
                UNUSED — decoding is greedy, not beam search.

        Returns:
            The decoded text string (empty if only blanks are predicted).
        """
        if isinstance(image_path_or_pil, (str, os.PathLike)):
            img = Image.open(image_path_or_pil).convert("RGB")
        else:
            img = image_path_or_pil.convert("RGB")
        tensor = self.tf(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(tensor)  # [T, 1, C] per-timestep class logits
        # Greedy CTC decode: argmax each timestep, collapse consecutive
        # repeats, and drop the blank symbol (index 0).
        _, preds = out.permute(1, 0, 2).max(2)
        chars, prev = [], None
        for p in preds[0].tolist():
            if p != 0 and p != prev:
                chars.append(self.idx2char.get(p, ""))
            prev = p
        return "".join(chars)

    def read_batch(self, images):
        """Read a list of images; returns one decoded string per image."""
        return [self.read(img) for img in images]