| """ |
| Custom handler for BERT-OJA-SkillLess on HF Inference Endpoints. |
| Uses ONNX Runtime with CUDA for 2-4x faster inference. |
| """ |
from typing import Dict, List, Any

import numpy as np
from transformers import AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Number of texts tokenized and scored per forward pass.
        self.batch_size = 1024

        # Prefer ONNX Runtime with the CUDA execution provider; fall back to
        # plain PyTorch if optimum/onnxruntime is unavailable or export fails.
        try:
            from optimum.onnxruntime import ORTModelForSequenceClassification
            # export=True converts the checkpoint to ONNX at load time.
            self.model = ORTModelForSequenceClassification.from_pretrained(
                path, export=True, provider="CUDAExecutionProvider",
            )
            self._use_ort = True
            print(f"[handler] Loaded ONNX model on CUDA (batch_size={self.batch_size})")
        except Exception as e:
            print(f"[handler] ONNX load failed ({e}); falling back to PyTorch")
            import torch
            from transformers import AutoModelForSequenceClassification
            self.model = AutoModelForSequenceClassification.from_pretrained(path)
            self.model.eval()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                # FP16 halves weight memory and speeds up GPU inference.
                self.model = self.model.to(self.device).half()
            self._use_ort = False
            self._torch = torch

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Score one string or a list of strings; returns label/score pairs per input."""
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        all_results = []
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i : i + self.batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True, max_length=128,
                return_tensors="np" if self._use_ort else "pt",
            )

            if self._use_ort:
                # ORTModel accepts the numpy-encoded batch directly.
                logits = self.model(**encoded).logits
                # Convert to numpy if the runtime returned a torch tensor.
                if hasattr(logits, "numpy"):
                    logits = logits.numpy()
                # Numerically stable softmax over the class axis.
                exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
                probs = exp / exp.sum(axis=-1, keepdims=True)
                for j in range(len(batch)):
                    all_results.append([
                        {"label": "LABEL_0", "score": round(float(probs[j][0]), 6)},
                        {"label": "LABEL_1", "score": round(float(probs[j][1]), 6)},
                    ])
            else:
                torch = self._torch
                # Move input tensors onto the same device as the model.
                encoded = {k: v.to(self.device) for k, v in encoded.items()}
                with torch.no_grad():
                    logits = self.model(**encoded).logits
                probs = torch.softmax(logits, dim=-1)
                for j in range(len(batch)):
                    all_results.append([
                        {"label": "LABEL_0", "score": round(probs[j][0].item(), 6)},
                        {"label": "LABEL_1", "score": round(probs[j][1].item(), 6)},
                    ])

        return all_results
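

# Minimal local smoke test: a sketch that assumes a model checkpoint directory
# at "./model" (tokenizer + weights). The Inference Endpoints runtime does not
# execute this block; it imports and instantiates EndpointHandler directly.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    predictions = handler({"inputs": [
        "Experience with Python and SQL required.",
        "No specific skills listed.",
    ]})
    for scores in predictions:
        print(scores)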