BERT-OJA-SkillLess / handler.py
mpalinski's picture
Upload handler.py with huggingface_hub
7eb969d verified
"""
Custom handler for BERT-OJA-SkillLess on HF Inference Endpoints.
Uses ONNX Runtime with CUDA (intended to give 2-4x faster inference) and
falls back to PyTorch (FP16 on GPU) if the ONNX model cannot be loaded.
"""
from typing import Dict, List, Any
import numpy as np
from transformers import AutoTokenizer
class EndpointHandler:
    """Callable handler that scores texts with a sequence-classification model.

    Loading prefers an ONNX Runtime model on the CUDA execution provider and
    falls back to a PyTorch model (FP16 when a GPU is available) if the ONNX
    path fails for any reason.

    Attributes set by ``__init__``:
        tokenizer: tokenizer loaded from ``path``.
        batch_size: maximum number of texts sent through the model at once.
        model: the ONNX Runtime or PyTorch classification model.
        device: ``"cuda"`` or ``"cpu"`` (PyTorch fallback only).
    """

    # Inputs longer than this many tokens are truncated by the tokenizer.
    MAX_LENGTH = 128

    def __init__(self, path=""):
        """Load the tokenizer and model found at ``path``.

        Args:
            path: location handed to the ``from_pretrained`` loaders.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.batch_size = 1024
        try:
            from optimum.onnxruntime import ORTModelForSequenceClassification

            self.model = ORTModelForSequenceClassification.from_pretrained(
                path, export=True, provider="CUDAExecutionProvider",
            )
            self._use_ort = True
            print(f"[handler] Loaded ONNX model on CUDA (batch_size={self.batch_size})")
        except Exception as e:
            # Deliberate best-effort: any failure on the ONNX path (missing
            # package, export error, provider unavailable) must not stop the
            # endpoint from starting, so fall back to plain PyTorch.
            print(f"[handler] ONNX failed ({e}), falling back to PyTorch FP16")
            import torch
            from transformers import AutoModelForSequenceClassification

            self.model = AutoModelForSequenceClassification.from_pretrained(path)
            self.model.eval()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if self.device == "cuda":
                self.model = self.model.to(self.device).half()
            self._use_ort = False
            # Keep the module around so __call__ does not need to re-import it.
            self._torch = torch

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
        """Classify the text(s) contained in a request payload.

        Args:
            data: request payload. Texts are read from the ``"inputs"`` key,
                or from ``"input"`` if that is absent; a single string is
                treated as a one-element list.

        Returns:
            One list per input text, in input order. Each list holds one
            ``{"label": "LABEL_<k>", "score": p}`` dict per model output
            class, where ``p`` is the softmax probability rounded to six
            decimals.
        """
        inputs = data.get("inputs", data.get("input", ""))
        if isinstance(inputs, str):
            inputs = [inputs]

        all_results = []
        for start in range(0, len(inputs), self.batch_size):
            batch = inputs[start : start + self.batch_size]
            all_results.extend(self._format_scores(self._predict_probs(batch)))
        return all_results

    def _predict_probs(self, batch):
        """Return softmax probabilities for one batch as nested lists of floats."""
        encoded = self.tokenizer(
            batch, padding=True, truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors="np" if self._use_ort else "pt",
        )

        if self._use_ort:
            logits = self.model(**encoded).logits
            if hasattr(logits, "numpy"):
                # Tensor-like output: convert so the NumPy softmax below applies.
                logits = logits.numpy()
            # Subtract the row maximum first so exp() cannot overflow.
            exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
            return (exp / exp.sum(axis=-1, keepdims=True)).tolist()

        torch = self._torch
        encoded = {k: v.to(self.device) for k, v in encoded.items()}
        with torch.no_grad():
            logits = self.model(**encoded).logits
            # The model may be running in FP16; take the softmax in FP32 so
            # the six-decimal scores are not limited by half precision.
            probs = torch.softmax(logits.float(), dim=-1)
        return probs.tolist()

    @staticmethod
    def _format_scores(prob_rows):
        """Turn rows of class probabilities into lists of label/score dicts."""
        return [
            [
                {"label": f"LABEL_{k}", "score": round(p, 6)}
                for k, p in enumerate(row)
            ]
            for row in prob_rows
        ]