import onnxruntime as ort
import torch
from PIL import Image
import torchvision.transforms as T
import numpy as np
import string
import logging
import os
from typing import List, Tuple
from torch import Tensor
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class TokenDecoder:
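    """Greedy decoder for the model's output token distributions.

    The vocabulary layout appears to follow the PARSeq-style tokenizer: '<eos>'
    at index 0, followed by the 94-character charset (digits, ASCII letters,
    punctuation), with '<sos>' and '<pad>' appended last. Only the first 95
    entries (eos + charset) correspond to the model's output classes.
    """
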
    def __init__(self):
        self.specials_first = ('<eos>',)  # [E]
        self.specials_last = ('<sos>', '<pad>')  # [B], [P]
        self.charset = tuple(string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation)
        self.itos = self.specials_first + self.charset + self.specials_last
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.eos_id = self.stoi['<eos>']
        self.sos_id = self.stoi['<sos>']
        self.pad_id = self.stoi['<pad>']
        logger.info(f"Initialized TokenDecoder with {len(self.itos)} tokens, including {len(self.charset)} charset tokens.")

    def ids2tok(self, token_ids: List[int], join: bool = True) -> str:
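        """Map token ids back to characters; returns a joined string by default,
        or the token list when join=False."""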
        tokens = [self.itos[i] for i in token_ids if i < len(self.itos)]  # Skip invalid indices
        return ''.join(tokens) if join else tokens

    def filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
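        """Truncate a greedy prediction at the first '<eos>' token.

        Returns the ids before '<eos>' and their matching probabilities; if no
        '<eos>' is present, the full sequence is kept.
        """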
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # No EOS, take all
        ids = ids[:eos_idx]  # Exclude EOS and beyond
        probs = probs[:eos_idx]  # Probabilities up to (excluding) EOS
        return probs, ids

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
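        """Greedy-decode a batch of token distributions.

        Takes an (N, L, C) tensor of per-position distributions, picks the most
        likely token at each position, and (unless raw=True) truncates each
        sequence at '<eos>' before converting ids to a string.
        """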
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # Greedy selection
            if not raw:
                probs, ids = self.filter(probs, ids)
            tokens = self.ids2tok(ids)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs


def infer_onnx(onnx_path: str, image_path: str) -> str:
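    """Run CAPTCHA recognition on a single image with an exported ONNX model.

    Loads the model into a CPU ONNX Runtime session, preprocesses the image to
    a normalized (1, 3, 32, 128) tensor, runs the forward pass, and greedy-
    decodes the resulting token distributions into the predicted string.
    """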
    try:
        # Verify ONNX model exists
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {onnx_path}")

        # Initialize ONNX runtime session
        logger.info(f"Loading ONNX model from {onnx_path}")
        session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
        input_name = session.get_inputs()[0].name

        # Verify image exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        # Preprocess image
        logger.info(f"Processing image {image_path}")
        img = Image.open(image_path).convert('RGB')
        transform = T.Compose([
            T.Resize((32, 128)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
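        # Note: the mean/std above are the standard ImageNet statistics; they
        # are assumed here to match the preprocessing used when the model was exported.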
        img_tensor = transform(img).unsqueeze(0).numpy()  # (1, 3, 32, 128)

        # Run inference
        logger.info("Running inference")
        outputs = session.run(None, {input_name: img_tensor})[0]  # (1, seq_len, 95)
        logits = torch.from_numpy(outputs)
        probs = logits.softmax(-1)  # Convert raw logits to per-token probabilities

        # Decode predictions
        decoder = TokenDecoder()
        pred, conf_scores = decoder.decode(probs)
        logger.info(f"Prediction: {pred[0]}")
        logger.info(f"Confidence scores: {conf_scores[0].numpy().tolist()}")
        return pred[0]
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}")
        raise


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Perform inference with ONNX model.')
    parser.add_argument('--onnx', required=True, help='Path to ONNX model')
    parser.add_argument('--image', required=True, help='Path to input CAPTCHA image')
    args = parser.parse_args()
    infer_onnx(args.onnx, args.image)
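
# Example invocation (file paths are illustrative):
#   python infer_onnx.py --onnx model.onnx --image captcha.png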