import onnxruntime as ort
import torch
from PIL import Image
import torchvision.transforms as T
import string
import logging
import os
from typing import List, Tuple, Union
from torch import Tensor

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class TokenDecoder:
    """Maps greedy token-ID predictions back to text, stripping special tokens."""

    def __init__(self):
        self.specials_first = ('<eos>',)         # [E]
        self.specials_last = ('<sos>', '<pad>')  # [B], [P]
        self.charset = tuple(string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation)
        self.itos = self.specials_first + self.charset + self.specials_last
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.eos_id = self.stoi['<eos>']
        self.sos_id = self.stoi['<sos>']
        self.pad_id = self.stoi['<pad>']
        logger.info(f"Initialized TokenDecoder with {len(self.itos)} tokens, including {len(self.charset)} charset tokens.")
    def ids2tok(self, token_ids: List[int], join: bool = True) -> Union[str, List[str]]:
        tokens = [self.itos[i] for i in token_ids if 0 <= i < len(self.itos)]  # Skip out-of-range indices
        return ''.join(tokens) if join else tokens
    def filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)  # No EOS; take the whole sequence
        ids = ids[:eos_idx]      # Exclude EOS and everything after it
        probs = probs[:eos_idx]  # Probabilities up to (excluding) EOS
        return probs, ids
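    # Example: '<eos>' is first in itos, so eos_id == 0. Given ids == [12, 5, 0, 3],
    # filter() truncates to [12, 5] and slices probs to match.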
    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # Greedy selection
            if not raw:
                probs, ids = self.filter(probs, ids)
            tokens = self.ids2tok(ids)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs
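# Illustrative usage sketch for TokenDecoder; the batch size and sequence length
# below are arbitrary stand-ins, not values taken from the model:
#   decoder = TokenDecoder()
#   dummy = torch.randn(1, 26, len(decoder.itos)).softmax(-1)  # (batch, seq_len, vocab)
#   texts, confs = decoder.decode(dummy)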
def infer_onnx(onnx_path: str, image_path: str) -> str:
    try:
        # Verify ONNX model exists
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {onnx_path}")

        # Initialize ONNX Runtime session
        logger.info(f"Loading ONNX model from {onnx_path}")
        session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
        input_name = session.get_inputs()[0].name

        # Verify image exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        # Preprocess image
        logger.info(f"Processing image {image_path}")
        img = Image.open(image_path).convert('RGB')
        transform = T.Compose([
            T.Resize((32, 128)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img_tensor = transform(img).unsqueeze(0).numpy()  # (1, 3, 32, 128)

        # Run inference
        logger.info("Running inference")
        outputs = session.run(None, {input_name: img_tensor})[0]  # (1, seq_len, 95)
        logits = torch.from_numpy(outputs)
        probs = logits.softmax(-1)  # Softmax so the logged confidence scores are probabilities, not raw logits

        # Decode predictions
        decoder = TokenDecoder()
        pred, conf_scores = decoder.decode(probs)
        logger.info(f"Prediction: {pred[0]}")
        logger.info(f"Confidence scores: {conf_scores[0].numpy().tolist()}")
        return pred[0]
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}")
        raise
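# Note: the session above is pinned to CPUExecutionProvider. With onnxruntime-gpu
# installed, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] would try
# the GPU first and fall back to CPU (assumes a CUDA-capable environment).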
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Perform inference with an ONNX model.')
    parser.add_argument('--onnx', required=True, help='Path to ONNX model')
    parser.add_argument('--image', required=True, help='Path to input CAPTCHA image')
    args = parser.parse_args()
    infer_onnx(args.onnx, args.image)
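# Example invocation (file names are placeholders, not paths from the source):
#   python infer_onnx.py --onnx model.onnx --image captcha.png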