|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> in-place ReLU, packaged as a single module.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels (also the BatchNorm feature count).
        k: Convolution kernel size.
        s: Convolution stride.
        p: Convolution zero-padding.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        layers = [
            # No conv bias: the BatchNorm that follows has its own learnable shift.
            nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # Kept as `self.block` (indices 0/1/2) so state_dict keys stay stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv / batch-norm / ReLU stack."""
        return self.block(x)
|
|
class LightCNN(nn.Module):
    """Compact VGG-style feature extractor that collapses an image to a column sequence.

    Five stages of two ConvBnRelu layers each:
      * b1, b2 end in MaxPool2d(2, 2): height and width are both halved (floor).
      * b3, b4 end in MaxPool2d((2, 1)): stride defaults to the kernel, so only
        the height is halved and horizontal resolution is preserved.
      * b5 has no pooling.
    An adaptive average pool then reduces the remaining height to 1.

    Args:
        in_channels: Channel count of the input image. Defaults to 1, the value
            that used to be hard-coded, so existing callers and checkpoints are
            unaffected.

    Expects a 4-D (batch, in_channels, H, W) tensor. H must be at least 16,
    otherwise the four height-halving pools leave an empty feature map.
    Returns a (batch, 256, W // 4) tensor.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Attribute names b1..b5 / pool are kept so state_dict keys are unchanged.
        self.b1 = nn.Sequential(ConvBnRelu(in_channels, 32), ConvBnRelu(32, 32), nn.MaxPool2d(2, 2))
        self.b2 = nn.Sequential(ConvBnRelu(32, 64), ConvBnRelu(64, 64), nn.MaxPool2d(2, 2))
        self.b3 = nn.Sequential(ConvBnRelu(64, 128), ConvBnRelu(128, 128), nn.MaxPool2d((2, 1)))
        self.b4 = nn.Sequential(ConvBnRelu(128, 256), ConvBnRelu(256, 256), nn.MaxPool2d((2, 1)))
        self.b5 = nn.Sequential(ConvBnRelu(256, 256), ConvBnRelu(256, 256))
        # (1, None): average the height down to 1, keep the width as-is.
        self.pool = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        """Extract features from ``x`` and return them as (batch, 256, width)."""
        for stage in (self.b1, self.b2, self.b3, self.b4, self.b5):
            x = stage(x)
        # Height is 1 after pooling; drop that axis: (B, 256, 1, W') -> (B, 256, W').
        return self.pool(x).squeeze(2)
|
|
class BidirectionalLSTM(nn.Module):
    """One bidirectional LSTM layer followed by a per-time-step linear projection.

    Args:
        input_size: Feature size of each input time step.
        hidden_size: Hidden size of each LSTM direction.
        output_size: Feature size produced for each time step.

    The LSTM is batch-first, so inputs are (batch, seq, input_size) and outputs
    are (batch, seq, output_size).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        # Forward and backward hidden states are concatenated: 2 * hidden_size wide.
        self.linear = nn.Linear(2 * hidden_size, output_size)

    def forward(self, x):
        """Encode ``x`` in both directions and project every time step."""
        recurrent, _state = self.rnn(x)
        return self.linear(recurrent)
|
|
class Model(nn.Module):
    """CNN + two stacked BiLSTMs producing per-column class log-probabilities.

    EasyOCR requires the class to be named Model.

    Args:
        input_channel: Accepted for the expected constructor signature but NOT
            used; LightCNN is built for its default single-channel input.
            NOTE(review): wiring this through would change the first conv's
            weight shape and break existing checkpoints whenever the configured
            value is not 1, so it is deliberately left unused.
        output_channel: Accepted but NOT used. The CNN always emits 256
            channels, which is the input size the first BiLSTM is built for.
        hidden_size: Hidden size per LSTM direction. Also the width of the
            projection between the two BiLSTMs.
        num_class: Number of output classes per time step.
    """

    def __init__(self, input_channel, output_channel, hidden_size, num_class):
        super().__init__()
        self.cnn = LightCNN()
        self.rnn = nn.Sequential(
            BidirectionalLSTM(256, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, num_class),
        )

    def forward(self, x, text=None):
        """Recognize the image batch ``x``.

        Args:
            x: Image batch passed straight to the CNN.
            text: Ignored. It is optional so that callers passing a second
                argument do not hit a TypeError; EasyOCR's recognizer appears to
                call ``model(image, text_for_pred)`` (verify against the
                installed EasyOCR version). Single-argument calls keep working.

        Returns:
            Tensor of shape (T, batch, num_class) holding log_softmax over
            classes, where T is the CNN output width. Time-major layout is
            presumably meant for ``torch.nn.CTCLoss``.
            NOTE(review): EasyOCR's built-in recognizers appear to return
            batch-first raw scores. Confirm which layout inference expects.
        """
        features = self.cnn(x)                 # (B, 256, T)
        features = features.permute(0, 2, 1)   # (B, T, 256) for the batch-first LSTMs
        scores = self.rnn(features)            # (B, T, num_class)
        scores = scores.permute(1, 0, 2)       # (T, B, num_class), time-major
        return F.log_softmax(scores, dim=2)
|
|