jtrecenti commited on
Commit
1483b7a
·
1 Parent(s): 0eefeb8

update predict function

Browse files
Files changed (1) hide show
  1. modeling_tjmg.py +25 -0
modeling_tjmg.py CHANGED
@@ -16,6 +16,7 @@ class CaptchaModel(PreTrainedModel):
16
  self.vocab = config.vocab
17
  self.output_ndigits = config.output_ndigits
18
  self.output_vocab_size = config.output_vocab_size
 
19
 
20
  self.batchnorm0 = nn.BatchNorm2d(3)
21
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
@@ -48,6 +49,30 @@ class CaptchaModel(PreTrainedModel):
48
  for _ in range(3):
49
  x = self.calc_dim_img_one(x)
50
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def forward(self, x):
52
  # Passagem pela rede (observe que não usamos pipes, apenas chamadas sequenciais)
53
  x = self.batchnorm0(x)
 
16
  self.vocab = config.vocab
17
  self.output_ndigits = config.output_ndigits
18
  self.output_vocab_size = config.output_vocab_size
19
+ self.input_dim = config.input_dim
20
 
21
  self.batchnorm0 = nn.BatchNorm2d(3)
22
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
 
49
  for _ in range(3):
50
  x = self.calc_dim_img_one(x)
51
  return x
52
+
53
+ def predict_captcha(self, file_path):
54
+ """
55
+ Realiza a predição do captcha para uma imagem específica.
56
+ """
57
+ # Carrega a imagem e aplica as transformações
58
+ transform = transforms.Compose([
59
+ transforms.Resize(self.input_dim),
60
+ transforms.ToTensor(),
61
+ ])
62
+ image = Image.open(file_path).convert('RGB')
63
+ image = transform(image)
64
+ image = image.unsqueeze(0) # Adiciona uma dimensão para o batch
65
+
66
+ # Realiza a predição
67
+ with torch.no_grad():
68
+ logits = self.forward(image)
69
+
70
+ # Obtém a predição (índice da classe com maior probabilidade)
71
+ preds = torch.argmax(logits, dim=2)
72
+ predicted_label = "".join([self.vocab[i] for i in preds[0].tolist()])
73
+
74
+ return predicted_label
75
+
76
  def forward(self, x):
77
  # Passagem pela rede (observe que não usamos pipes, apenas chamadas sequenciais)
78
  x = self.batchnorm0(x)