# shobdo-ocr / bengali / shobdo_bengali.py
# Author: Sarjinkhan2003
# Bengali model architecture
# Commit: f964f83 (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnRelu(nn.Module):
    """Conv2d (without bias) -> BatchNorm2d -> in-place ReLU as a single module.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        k: Convolution kernel size.
        s: Convolution stride.
        p: Zero-padding added on each spatial side.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        # Presumably bias-free because the following BatchNorm2d supplies its
        # own learnable shift, which would make a conv bias redundant.
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        act = nn.ReLU(inplace=True)
        # Keep one Sequential named `block` so state_dict keys (block.0.*,
        # block.1.*) stay unchanged for loading previously saved weights.
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply convolution, batch-norm and ReLU to ``x`` in that order."""
        out = self.block(x)
        return out
class LightCNN(nn.Module):
    """Five-stage convolutional feature extractor for single-channel images.

    Each stage is two ConvBnRelu layers (default kernel/stride/padding).
    Stages b1 and b2 end in a 2x2 max-pool (height and width halved, rounding
    down); b3 and b4 end in a (2, 1) max-pool (height halved, width kept);
    b5 has no pooling. A final adaptive average pool collapses the height to 1.
    """

    def __init__(self):
        super().__init__()

        def stage(cin, cout, *tail):
            # Two ConvBnRelu layers followed by any extra layers (here: pooling).
            return nn.Sequential(ConvBnRelu(cin, cout), ConvBnRelu(cout, cout), *tail)

        # Attribute names and registration order are preserved so state_dict
        # keys (b1.0.block.0.weight, ...) match previously saved weights.
        self.b1 = stage(1, 32, nn.MaxPool2d(2, 2))
        self.b2 = stage(32, 64, nn.MaxPool2d(2, 2))
        # Kernel (2, 1) with stride defaulting to the kernel: only height shrinks.
        self.b3 = stage(64, 128, nn.MaxPool2d((2, 1)))
        self.b4 = stage(128, 256, nn.MaxPool2d((2, 1)))
        self.b5 = stage(256, 256)
        # Output size (1, None): average over height only, width is left as-is.
        self.pool = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        """Run all five stages, pool the height to 1 and squeeze that axis.

        For a 4-D input the result is (batch, 256, width').
        """
        out = x
        for layer in (self.b1, self.b2, self.b3, self.b4, self.b5):
            out = layer(out)
        pooled = self.pool(out)
        return pooled.squeeze(2)
class BidirectionalLSTM(nn.Module):
    """Single-layer bidirectional LSTM followed by a per-timestep linear projection.

    The LSTM is batch-first, so a batched input is (batch, seq, input_size) and
    the output is (batch, seq, output_size). Forward and backward hidden states
    are concatenated (2 * hidden_size features) before the projection.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size,
            hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # Both directions are concatenated, hence twice the hidden size.
        self.linear = nn.Linear(2 * hidden_size, output_size)

    def forward(self, x):
        """Run the BiLSTM over ``x`` and project every timestep."""
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        return self.linear(self.rnn(x)[0])
class Model(nn.Module):
    """CNN feature extractor followed by two stacked BiLSTMs, emitting per-step log-probs.

    EasyOCR requires the class to be named Model.

    Args:
        input_channel: Accepted but unused here; the CNN's first convolution
            is fixed to 1 input channel.
        output_channel: Accepted but unused here; the CNN always produces
            256 feature channels.
        hidden_size: Hidden size of both BiLSTMs (also the output size of the
            first one).
        num_class: Number of classes scored at each timestep.
    """

    def __init__(self, input_channel, output_channel, hidden_size, num_class):
        super().__init__()
        self.cnn = LightCNN()
        # 256 must match the channel count of the CNN's last stage.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(256, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, num_class),
        )

    def forward(self, x, text=None):
        """Return class log-probabilities for each horizontal position.

        Args:
            x: Image batch. The CNN's first conv expects 1 channel, so
                presumably (batch, 1, height, width); TODO confirm.
            text: Ignored. Accepted (default None) because EasyOCR's
                recognizer calls ``model(image, text_for_pred)``. Without
                this parameter that call raises TypeError.

        Returns:
            Time-major tensor (seq, batch, num_class) holding log_softmax
            over the class axis.
        """
        # (batch, 256, seq): the CNN pools height to 1 and squeezes it away.
        features = self.cnn(x)
        # The BiLSTMs are batch_first, so move channels last: (batch, seq, 256).
        features = features.permute(0, 2, 1)
        logits = self.rnn(features)  # (batch, seq, num_class)
        # Time-major layout (seq, batch, num_class), the layout nn.CTCLoss takes.
        # NOTE(review): EasyOCR's recognizer_predict appears to read predictions
        # as (batch, seq, num_class) (it uses preds.size(1) as the length);
        # confirm the layout before relying on this model for EasyOCR inference.
        logits = logits.permute(1, 0, 2)
        return F.log_softmax(logits, dim=2)