| from transformers import pipeline, Pipeline, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification |
| from transformers.pipelines import PIPELINE_REGISTRY |
| import torch |
|
|
|
|
class SpanClassificationPipeline(Pipeline):
    """Sequence-classification pipeline that returns the arg-max class index.

    Each input is tokenized, run through the model without gradient
    tracking, and reduced to a single ``int``: the index of the largest
    logit.

    Extra keyword arguments given at construction time or at call time
    (for example ``truncation=True`` or ``max_length=128``) are treated as
    tokenizer options and forwarded to the tokenizer in ``preprocess``.
    """

    def __init__(self, model, tokenizer, device="cpu", **kwargs):
        """Build the pipeline and put the model in inference mode.

        Args:
            model: Model instance handed to the base ``Pipeline``.
            tokenizer: Tokenizer used by ``preprocess``.
            device: Device specifier forwarded to the base ``Pipeline``.
            **kwargs: Forwarded unchanged to the base ``Pipeline``.
        """
        super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
        self.model.to(self.device)
        # Disable dropout and other training-only behaviour.
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split user kwargs into (preprocess, forward, postprocess) dicts.

        ``_forward`` and ``postprocess`` accept no options, so everything is
        routed to ``preprocess``. Routing the kwargs to ``_forward`` instead
        would make any extra argument fail with ``TypeError``, because
        ``_forward`` only takes ``model_inputs``.
        """
        return kwargs, {}, {}

    def preprocess(self, inputs, **tokenizer_kwargs):
        """Tokenize ``inputs`` into PyTorch tensors on the pipeline device.

        Args:
            inputs: Raw input passed straight to the tokenizer.
            **tokenizer_kwargs: Optional tokenizer options (e.g.
                ``truncation``, ``max_length``, ``padding``).

        Returns:
            The tokenizer output moved to ``self.device``.
        """
        # The model consumes torch tensors, so ``return_tensors`` is pinned
        # to "pt" even if a caller supplies a different value.
        encode_options = {**tokenizer_kwargs, "return_tensors": "pt"}
        return self.tokenizer(inputs, **encode_options).to(self.device)

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs without tracking gradients."""
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Return the predicted class index as a plain ``int``.

        ``.item()`` only works on a single-element tensor, so this expects
        logits for exactly one example (assumed shape ``(1, num_labels)``).
        """
        logits = model_outputs.logits
        return int(torch.argmax(logits, dim=1).item())
|
|
|
|
# Register the custom task at import time so that
# ``pipeline("spancnn-classification", ...)`` resolves to
# SpanClassificationPipeline backed by a PyTorch
# sequence-classification model.
PIPELINE_REGISTRY.register_pipeline(
    "spancnn-classification",
    pipeline_class=SpanClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)
|
|