Sarjinkhan2003 commited on
Commit
3bbb1c3
·
verified ·
1 Parent(s): f1b119d

Script classifier architecture

Browse files
Files changed (1) hide show
  1. classifier/shobdo_classifier.py +26 -0
classifier/shobdo_classifier.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn as nn
3
+
4
+ class ScriptClassifier(nn.Module):
5
+ """Lightweight Bengali/English script classifier. 23K params, ~0.1MB."""
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.features = nn.Sequential(
9
+ nn.Conv2d(1,16,3,padding=1), nn.BatchNorm2d(16), nn.ReLU(True),
10
+ nn.MaxPool2d(2),
11
+ nn.Conv2d(16,32,3,padding=1), nn.BatchNorm2d(32), nn.ReLU(True),
12
+ nn.MaxPool2d(2),
13
+ nn.Conv2d(32,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(True),
14
+ nn.MaxPool2d(2),
15
+ nn.Conv2d(64,64,3,padding=1), nn.BatchNorm2d(64), nn.ReLU(True),
16
+ nn.AdaptiveAvgPool2d((1,1)),
17
+ )
18
+ self.classifier = nn.Sequential(
19
+ nn.Flatten(), nn.Dropout(0.3), nn.Linear(64,2)
20
+ )
21
+ def forward(self, x):
22
+ return self.classifier(self.features(x))
23
+ def predict(self, x):
24
+ """x: (1,1,64,256) tensor. Returns 'bengali' or 'english'."""
25
+ with __import__('torch').no_grad():
26
+ return ['bengali','english'][self.forward(x).argmax(1).item()]