update predict function
Browse files- modeling_tjmg.py +25 -0
modeling_tjmg.py
CHANGED
|
@@ -16,6 +16,7 @@ class CaptchaModel(PreTrainedModel):
|
|
| 16 |
self.vocab = config.vocab
|
| 17 |
self.output_ndigits = config.output_ndigits
|
| 18 |
self.output_vocab_size = config.output_vocab_size
|
|
|
|
| 19 |
|
| 20 |
self.batchnorm0 = nn.BatchNorm2d(3)
|
| 21 |
self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
|
|
@@ -48,6 +49,30 @@ class CaptchaModel(PreTrainedModel):
|
|
| 48 |
for _ in range(3):
|
| 49 |
x = self.calc_dim_img_one(x)
|
| 50 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
def forward(self, x):
|
| 52 |
# Passagem pela rede (observe que não usamos pipes, apenas chamadas sequenciais)
|
| 53 |
x = self.batchnorm0(x)
|
|
|
|
| 16 |
self.vocab = config.vocab
|
| 17 |
self.output_ndigits = config.output_ndigits
|
| 18 |
self.output_vocab_size = config.output_vocab_size
|
| 19 |
+
self.input_dim = config.input_dim
|
| 20 |
|
| 21 |
self.batchnorm0 = nn.BatchNorm2d(3)
|
| 22 |
self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
|
|
|
|
| 49 |
for _ in range(3):
|
| 50 |
x = self.calc_dim_img_one(x)
|
| 51 |
return x
|
| 52 |
+
|
| 53 |
+
def predict_captcha(self, file_path):
|
| 54 |
+
"""
|
| 55 |
+
Realiza a predição do captcha para uma imagem específica.
|
| 56 |
+
"""
|
| 57 |
+
# Carrega a imagem e aplica as transformações
|
| 58 |
+
transform = transforms.Compose([
|
| 59 |
+
transforms.Resize(self.input_dim),
|
| 60 |
+
transforms.ToTensor(),
|
| 61 |
+
])
|
| 62 |
+
image = Image.open(file_path).convert('RGB')
|
| 63 |
+
image = transform(image)
|
| 64 |
+
image = image.unsqueeze(0) # Adiciona uma dimensão para o batch
|
| 65 |
+
|
| 66 |
+
# Realiza a predição
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
logits = self.forward(image)
|
| 69 |
+
|
| 70 |
+
# Obtém a predição (índice da classe com maior probabilidade)
|
| 71 |
+
preds = torch.argmax(logits, dim=2)
|
| 72 |
+
predicted_label = "".join([self.vocab[i] for i in preds[0].tolist()])
|
| 73 |
+
|
| 74 |
+
return predicted_label
|
| 75 |
+
|
| 76 |
def forward(self, x):
|
| 77 |
# Passagem pela rede (observe que não usamos pipes, apenas chamadas sequenciais)
|
| 78 |
x = self.batchnorm0(x)
|