| from transformers import Pipeline, PreTrainedTokenizer, AutoTokenizer |
| from typing import Dict, Union, List |
| import torch |
|
|
class TokenizerPipeline(Pipeline):
    """A tokenize-only Pipeline: encodes text with the tokenizer and skips
    any model forward pass (`_forward` is an identity pass-through)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        `padding`/`truncation` go to the tokenizer call in `preprocess`;
        `return_tokens` controls the output format in `postprocess`.
        """
        pre_kwargs = {
            key: kwargs[key]
            for key in ("padding", "truncation")
            if key in kwargs
        }

        post_kwargs = {}
        if "return_tokens" in kwargs:
            post_kwargs["return_tokens"] = kwargs["return_tokens"]

        # No forward-stage parameters: _forward is a no-op.
        return pre_kwargs, {}, post_kwargs

    def preprocess(self, inputs, **kwargs) -> Dict:
        """Encode the raw text input into PyTorch tensors."""
        return self.tokenizer(inputs, return_tensors="pt", **kwargs)

    def _forward(self, inputs) -> Dict:
        """No model involved — pass the encoded batch straight through."""
        return inputs

    def postprocess(self, model_outputs, **kwargs) -> Dict:
        """Convert the encoded batch to either token strings or raw ids.

        Returns {"tokens": [...]} by default, or {"input_ids": [...]}
        when return_tokens=False. Only the first sequence in the batch
        is used.
        """
        ids = model_outputs["input_ids"][0]

        if not kwargs.get("return_tokens", True):
            return {"input_ids": ids.tolist()}
        return {"tokens": self.tokenizer.convert_ids_to_tokens(ids)}
|
|
# Load tokenizer files from the current working directory
# (assumes tokenizer.json / vocab files sit alongside this module — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(".")
# Module-level singleton; constructed eagerly at import time.
pipeline = TokenizerPipeline(tokenizer=tokenizer)
|
|
| |
def get_pipeline() -> Pipeline:
    """Return the shared, module-level TokenizerPipeline instance."""
    return pipeline
|
|