| from typing import Dict, Any, List, Union |
|
|
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForSequenceClassification, |
| TextClassificationPipeline, |
| ) |
|
|
|
|
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    Loads a fine-tuned text-classification model and exposes a callable
    that the endpoint runtime will invoke. The runtime will instantiate
    this class once at startup, passing the model directory path.
    """

    def __init__(self, path: str = "", **kwargs):
        """Load tokenizer, model, and classification pipeline from *path*.

        Args:
            path: Directory holding the fine-tuned model artifacts.
                Falls back to the current working directory when empty.
            **kwargs: Ignored; accepted so the endpoint runtime can pass
                extra arguments without breaking this handler.
        """
        model_dir = path or "."

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)

        # Multi-label heads score each class independently (sigmoid);
        # single-label heads use a softmax over all classes.
        is_multi_label = (
            getattr(self.model.config, "problem_type", None)
            == "multi_label_classification"
        )

        self.pipeline = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            device=-1,  # CPU; the endpoint runtime controls hardware placement
            # `return_all_scores=False` is deprecated; `top_k=1` is its exact
            # modern equivalent (one best {label, score} per input).
            top_k=1,
            function_to_apply="sigmoid" if is_multi_label else "softmax",
        )

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Run inference on the incoming request.

        Expected input format from the Inference Endpoint runtime:
            {
                "inputs": "some text" | ["text 1", "text 2", ...],
                "parameters": { ... }  # optional pipeline kwargs (e.g., top_k)
            }

        Returns:
            A single result dict when the caller sent a bare string, or a
            list of result dicts (one per text) when the caller sent a list.
        """
        raw_inputs = data.get("inputs", data)

        # Remember whether the caller sent a bare string so we can unwrap
        # the result symmetrically. (Unwrapping on `len(outputs) == 1`
        # would wrongly collapse a one-element list input into a dict.)
        single_input = isinstance(raw_inputs, str)
        if single_input:
            raw_inputs = [raw_inputs]

        # Guard against an explicit `"parameters": null` in the payload.
        parameters = data.get("parameters", {}) or {}

        outputs = self.pipeline(raw_inputs, **parameters)

        if single_input:
            return outputs[0]
        return outputs