Sarjinkhan2003 commited on
Commit
f964f83
·
verified ·
1 Parent(s): 13d2fbe

Bengali model architecture

Browse files
Files changed (1) hide show
  1. bengali/shobdo_bengali.py +36 -23
bengali/shobdo_bengali.py CHANGED
@@ -1,39 +1,52 @@
1
 
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
  class ConvBnRelu(nn.Module):
6
- def __init__(self,i,o,k=3,s=1,p=1):
7
  super().__init__()
8
- self.b=nn.Sequential(nn.Conv2d(i,o,k,s,p,bias=False),nn.BatchNorm2d(o),nn.ReLU(inplace=True))
9
- def forward(self,x): return self.b(x)
 
 
 
 
10
 
11
  class LightCNN(nn.Module):
12
  def __init__(self):
13
  super().__init__()
14
- self.b1=nn.Sequential(ConvBnRelu(1,32),ConvBnRelu(32,32),nn.MaxPool2d(2,2))
15
- self.b2=nn.Sequential(ConvBnRelu(32,64),ConvBnRelu(64,64),nn.MaxPool2d(2,2))
16
- self.b3=nn.Sequential(ConvBnRelu(64,128),ConvBnRelu(128,128),nn.MaxPool2d((2,1)))
17
- self.b4=nn.Sequential(ConvBnRelu(128,256),ConvBnRelu(256,256),nn.MaxPool2d((2,1)))
18
- self.b5=nn.Sequential(ConvBnRelu(256,256),ConvBnRelu(256,256))
19
- self.pool=nn.AdaptiveAvgPool2d((1,None))
20
- def forward(self,x):
21
- for b in [self.b1,self.b2,self.b3,self.b4,self.b5]: x=b(x)
22
  return self.pool(x).squeeze(2)
23
 
24
- class BiLSTM(nn.Module):
25
- def __init__(self,i,h,o):
26
  super().__init__()
27
- self.rnn=nn.LSTM(i,h,bidirectional=True,batch_first=True)
28
- self.fc=nn.Linear(h*2,o)
29
- def forward(self,x): o,_=self.rnn(x); return self.fc(o)
 
 
30
 
31
  class Model(nn.Module):
32
- def __init__(self,input_channel,output_channel,hidden_size,num_class):
 
33
  super().__init__()
34
- self.cnn=LightCNN()
35
- self.rnn=nn.Sequential(BiLSTM(256,hidden_size,hidden_size),BiLSTM(hidden_size,hidden_size,num_class))
36
- def forward(self,x):
37
- f=self.cnn(x).permute(0,2,1)
38
- o=self.rnn(f).permute(1,0,2)
39
- return F.log_softmax(o,dim=2)
 
 
 
 
 
 
1
 
2
+ import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
 
6
  class ConvBnRelu(nn.Module):
7
+ def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
8
  super().__init__()
9
+ self.block = nn.Sequential(
10
+ nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
11
+ nn.BatchNorm2d(out_ch),
12
+ nn.ReLU(inplace=True)
13
+ )
14
+ def forward(self, x): return self.block(x)
15
 
16
  class LightCNN(nn.Module):
17
  def __init__(self):
18
  super().__init__()
19
+ self.b1 = nn.Sequential(ConvBnRelu(1,32),ConvBnRelu(32,32),nn.MaxPool2d(2,2))
20
+ self.b2 = nn.Sequential(ConvBnRelu(32,64),ConvBnRelu(64,64),nn.MaxPool2d(2,2))
21
+ self.b3 = nn.Sequential(ConvBnRelu(64,128),ConvBnRelu(128,128),nn.MaxPool2d((2,1)))
22
+ self.b4 = nn.Sequential(ConvBnRelu(128,256),ConvBnRelu(256,256),nn.MaxPool2d((2,1)))
23
+ self.b5 = nn.Sequential(ConvBnRelu(256,256),ConvBnRelu(256,256))
24
+ self.pool = nn.AdaptiveAvgPool2d((1, None))
25
+ def forward(self, x):
26
+ for b in [self.b1,self.b2,self.b3,self.b4,self.b5]: x = b(x)
27
  return self.pool(x).squeeze(2)
28
 
29
+ class BidirectionalLSTM(nn.Module):
30
+ def __init__(self, input_size, hidden_size, output_size):
31
  super().__init__()
32
+ self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
33
+ self.linear = nn.Linear(hidden_size*2, output_size)
34
+ def forward(self, x):
35
+ out, _ = self.rnn(x)
36
+ return self.linear(out)
37
 
38
  class Model(nn.Module):
39
+ """EasyOCR requires the class to be named Model"""
40
+ def __init__(self, input_channel, output_channel, hidden_size, num_class):
41
  super().__init__()
42
+ self.cnn = LightCNN()
43
+ self.rnn = nn.Sequential(
44
+ BidirectionalLSTM(256, hidden_size, hidden_size),
45
+ BidirectionalLSTM(hidden_size, hidden_size, num_class)
46
+ )
47
+ def forward(self, x):
48
+ f = self.cnn(x)
49
+ f = f.permute(0,2,1)
50
+ o = self.rnn(f)
51
+ o = o.permute(1,0,2)
52
+ return F.log_softmax(o, dim=2)