mrbob12 commited on
Commit
e4083ac
·
1 Parent(s): b076cce
Files changed (3) hide show
  1. app.py +18 -0
  2. infer_onnx.py +102 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import tempfile

import gradio as gr

from infer_onnx import infer_onnx

# Path to the ONNX model this app serves; loaded by infer_onnx per request.
MODEL_PATH = "model.onnx"


def predict_from_image(image):
    """Predict the CAPTCHA text for an uploaded PIL image.

    `infer_onnx` expects a filesystem path, so the image is written to a
    unique temporary file first. The original code saved to a fixed
    "uploaded.png", which concurrent requests would overwrite (a race);
    a per-request mkstemp path avoids that.
    """
    if image is None:
        return ""  # Gradio passes None when no image was supplied.
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # PIL reopens the path itself; close the raw descriptor.
    try:
        image.save(tmp_path)
        return infer_onnx(MODEL_PATH, tmp_path)
    finally:
        os.unlink(tmp_path)  # Always remove the temp file, even on error.


iface = gr.Interface(
    fn=predict_from_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ONNX CAPTCHA Inference",
    description="Upload an image to get prediction from ONNX model."
)

if __name__ == "__main__":
    iface.launch()
infer_onnx.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ import torch
3
+ from PIL import Image
4
+ import torchvision.transforms as T
5
+ import numpy as np
6
+ import string
7
+ import logging
8
+ import os
9
+ from typing import List, Tuple
10
+ from torch import Tensor
11
+
12
# Configure module-wide logging once at import time; all inference
# progress and errors are reported through this logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
15
+
16
class TokenDecoder:
    """Greedy decoder mapping per-step token distributions to strings.

    Vocabulary layout (built in __init__): index 0 is '<eos>', followed by
    94 printable-ASCII charset tokens (digits, lowercase, uppercase,
    punctuation), then '<sos>' and '<pad>' — 97 tokens total.
    """

    def __init__(self):
        self.specials_first = ('<eos>',)  # [E]
        self.specials_last = ('<sos>', '<pad>')  # [B], [P]
        # The 94 printable ASCII characters the model can emit.
        self.charset = tuple(string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation)
        self.itos = self.specials_first + self.charset + self.specials_last
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.eos_id = self.stoi['<eos>']
        self.sos_id = self.stoi['<sos>']
        self.pad_id = self.stoi['<pad>']
        logger.info(f"Initialized TokenDecoder with {len(self.itos)} tokens, including {len(self.charset)} charset tokens.")

    def ids2tok(self, token_ids: List[int], join: bool = True) -> "str | List[str]":
        """Map token ids to their string tokens.

        Returns a single joined string when `join` is True, otherwise the
        list of tokens. (The previous `-> str` annotation was wrong for
        join=False.) Out-of-range ids are silently skipped; the lower
        bound also rejects negative ids, which previously wrapped around
        the vocabulary via Python's negative indexing.
        """
        tokens = [self.itos[i] for i in token_ids if 0 <= i < len(self.itos)]
        return ''.join(tokens) if join else tokens

    def filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Truncate a greedy id sequence (and its probs) at the first EOS.

        Returns (probs, ids) with the EOS token and everything after it
        removed. If no EOS was emitted, the full sequence is kept.
        """
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # No EOS emitted: keep everything.
        return probs[:eos_idx], ids[:eos_idx]

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        """Greedy-decode a batch of token distributions into strings.

        Args:
            token_dists: per-sample, per-step distributions over the
                vocabulary; iterated along the first (batch) dimension.
            raw: when True, skip EOS truncation and decode every step.

        Returns:
            A pair (strings, per-step max-probability tensors), with one
            entry per sample in the batch.
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # Greedy: argmax at every step.
            if not raw:
                probs, ids = self.filter(probs, ids)
            batch_tokens.append(self.ids2tok(ids))
            batch_probs.append(probs)
        return batch_tokens, batch_probs
53
+
54
def infer_onnx(onnx_path: str, image_path: str) -> str:
    """Run CAPTCHA recognition on one image with an exported ONNX model.

    Args:
        onnx_path: path to the ONNX model file.
        image_path: path to the input CAPTCHA image.

    Returns:
        The decoded prediction string for the single input image.
        (The previous `-> None` annotation was wrong: the function
        returns `pred[0]`.)

    Raises:
        FileNotFoundError: if the model or image path does not exist.
        Exception: any ONNX-runtime / preprocessing error is logged with
            its traceback and re-raised unchanged.
    """
    try:
        # Verify the model exists before paying session-creation cost.
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {onnx_path}")

        # Initialize the CPU-only ONNX runtime session.
        logger.info(f"Loading ONNX model from {onnx_path}")
        session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
        input_name = session.get_inputs()[0].name

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        # Preprocess: force RGB, resize to the model's 32x128 input,
        # ImageNet mean/std normalization, then add a batch dimension.
        logger.info(f"Processing image {image_path}")
        img = Image.open(image_path).convert('RGB')
        transform = T.Compose([
            T.Resize((32, 128)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img_tensor = transform(img).unsqueeze(0).numpy()  # (1, 3, 32, 128)

        # First output is taken as the per-step token distributions —
        # presumably (1, seq_len, num_classes); TODO confirm against the
        # export script (TokenDecoder's vocabulary has 97 entries).
        logger.info("Running inference")
        outputs = session.run(None, {input_name: img_tensor})[0]
        logits = torch.from_numpy(outputs)

        # Greedy-decode logits into a string plus per-step confidences.
        decoder = TokenDecoder()
        pred, conf_scores = decoder.decode(logits)
        logger.info(f"Prediction: {pred[0]}")
        logger.info(f"Confidence scores: {conf_scores[0].numpy().tolist()}")

        return pred[0]

    except Exception as e:
        # logger.exception records the traceback (logger.error dropped it)
        # before re-raising for the caller to handle.
        logger.exception(f"Error during inference: {str(e)}")
        raise
95
+
96
if __name__ == '__main__':
    import argparse

    # Command-line entry point: run a single inference; results are
    # reported via the module logger inside infer_onnx.
    cli = argparse.ArgumentParser(description='Perform inference with ONNX model.')
    cli.add_argument('--onnx', required=True, help='Path to ONNX model')
    cli.add_argument('--image', required=True, help='Path to input CAPTCHA image')
    opts = cli.parse_args()
    infer_onnx(opts.onnx, opts.image)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pytorch-lightning
4
+ pillow
5
+ onnx
6
+ onnxruntime
7
+ flask
8
+ gradio