""" Custom handler for BERT-OJA-SkillLess on HF Inference Endpoints. Uses ONNX Runtime with CUDA for 2-4x faster inference. """ from typing import Dict, List, Any import numpy as np from transformers import AutoTokenizer class EndpointHandler: def __init__(self, path=""): self.tokenizer = AutoTokenizer.from_pretrained(path) self.batch_size = 1024 try: from optimum.onnxruntime import ORTModelForSequenceClassification self.model = ORTModelForSequenceClassification.from_pretrained( path, export=True, provider="CUDAExecutionProvider", ) self._use_ort = True print(f"[handler] Loaded ONNX model on CUDA (batch_size={self.batch_size})") except Exception as e: print(f"[handler] ONNX failed ({e}), falling back to PyTorch FP16") import torch from transformers import AutoModelForSequenceClassification self.model = AutoModelForSequenceClassification.from_pretrained(path) self.model.eval() self.device = "cuda" if torch.cuda.is_available() else "cpu" if self.device == "cuda": self.model = self.model.to(self.device).half() self._use_ort = False self._torch = torch def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]: inputs = data.get("inputs", data.get("input", "")) if isinstance(inputs, str): inputs = [inputs] all_results = [] for i in range(0, len(inputs), self.batch_size): batch = inputs[i : i + self.batch_size] encoded = self.tokenizer( batch, padding=True, truncation=True, max_length=128, return_tensors="pt" if not self._use_ort else "np", ) if self._use_ort: logits = self.model(**{k: v for k, v in encoded.items()}).logits if hasattr(logits, 'numpy'): logits = logits.numpy() exp = np.exp(logits - logits.max(axis=-1, keepdims=True)) probs = exp / exp.sum(axis=-1, keepdims=True) for j in range(len(batch)): all_results.append([ {"label": "LABEL_0", "score": round(float(probs[j][0]), 6)}, {"label": "LABEL_1", "score": round(float(probs[j][1]), 6)}, ]) else: torch = self._torch encoded = {k: v.to(self.device) for k, v in encoded.items()} with torch.no_grad(): logits = self.model(**encoded).logits probs = torch.softmax(logits, dim=-1) for j in range(len(batch)): all_results.append([ {"label": "LABEL_0", "score": round(probs[j][0].item(), 6)}, {"label": "LABEL_1", "score": round(probs[j][1].item(), 6)}, ]) return all_results