| from typing import Dict, List, Any |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
| |
# Checkpoint repo id loaded by EndpointHandler.__init__ (its `path` argument is
# not used for model selection).
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
|
class EndpointHandler:
    """Inference Endpoints handler wrapping a text-generation chat pipeline."""

    def __init__(self, path: str = ""):
        # NOTE(review): `path` (the endpoint-provided model directory) is ignored;
        # the fixed `model_name` checkpoint is loaded instead — confirm intended.
        # bfloat16 halves weight memory vs fp32; device_map="auto" lets
        # accelerate place the weights on the available device(s).
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Run batched generation over chat-format or plain-string inputs.

        Args:
            data: dict with keys
                inputs: List[List[Dict[str, str]]] (conversational format) or
                    List[str] (plain prompts). If the "inputs" key is absent,
                    `data` itself is used as the inputs (legacy behavior,
                    preserved).
                parameters: optional dict of generation kwargs forwarded to the
                    pipeline call.

        Returns:
            One dict per input with keys:
                next_chat_turn: the final generated chat message dict (chat
                    inputs) or the generated string (plain-string inputs).
                next_chat_text: the generated text as a plain string.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        if parameters is not None:
            predictions = self.pipeline(inputs, **parameters)
        else:
            predictions = self.pipeline(inputs)

        results = []
        for prediction in predictions:
            # Batched calls yield a list of candidate dicts per input; a
            # non-batched call can yield a bare dict. Normalize to one dict
            # (first candidate) either way.
            item = prediction[0] if isinstance(prediction, list) else prediction
            generated = item["generated_text"]
            if isinstance(generated, str):
                # Plain-string prompt: generated_text is the output text itself.
                # (The original code indexed into the string here and then
                # crashed on `['content']`.)
                results.append({
                    'next_chat_turn': generated,
                    'next_chat_text': generated,
                })
            else:
                # Chat prompt: generated_text is the whole conversation; the
                # last entry is the newly generated assistant turn.
                last_turn = generated[-1]
                results.append({
                    'next_chat_turn': last_turn,
                    'next_chat_text': last_turn['content'],
                })
        return results
|
|