import re

import torch
from transformers import Pipeline


class MyPipeline(Pipeline):
    """Question-generation pipeline.

    Cleans an input paragraph, prefixes it with a fixed Korean prompt
    ("질문 생성: <unused0>"), runs ``model.generate`` on it and decodes the
    first generated sequence into a question string.

    ``make_question`` additionally splits long text into sliding word windows
    and generates one question per window.
    """

    # Characters kept by ``preprocess``: ASCII letters, precomposed Hangul
    # syllables (가-힣), digits, space and ``,<>:&#``; everything else is
    # removed. Compiled once instead of on every call.
    _DISALLOWED_CHARS = re.compile(r'[^A-Za-z가-힣,<>0-9:&# ]')
    # Optional call-time options that are forwarded to ``model.generate``.
    _GENERATE_KEYS = ("max_length", "num_beams")

    def _sanitize_parameters(self, **kwargs):
        """Route call-time options to the pipeline stages.

        ``max_length`` and ``num_beams`` (when given) go to ``preprocess``,
        which embeds them in the model inputs for ``_forward``. Nothing is
        routed to ``_forward`` or ``postprocess`` directly; other keywords
        are ignored.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._GENERATE_KEYS if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, **kwargs):
        """Clean ``inputs`` and encode it as a batch of one token-id sequence.

        Args:
            inputs: raw paragraph text.
            **kwargs: optional ``max_length`` / ``num_beams``. Any other key
                (e.g. ``frame_size`` passed through by ``make_question``) is
                ignored.

        Returns:
            dict with ``"inputs"`` (``torch.tensor`` of shape ``(1, seq_len)``:
            BOS + encoded prompt + EOS), plus whichever of ``max_length`` /
            ``num_beams`` the caller supplied.
        """
        cleaned = self._DISALLOWED_CHARS.sub('', inputs)
        prompt = "질문 생성: <unused0>" + cleaned

        tokenizer = self.tokenizer
        input_ids = (
            [tokenizer.bos_token_id]
            + tokenizer.encode(prompt)
            + [tokenizer.eos_token_id]
        )
        model_inputs = {"inputs": torch.tensor([input_ids])}
        # Only pass on options the caller actually gave. Missing ones fall
        # back to the model's generation defaults instead of raising KeyError.
        for key in self._GENERATE_KEYS:
            if key in kwargs:
                model_inputs[key] = kwargs[key]
        return model_inputs

    def _forward(self, model_inputs):
        """Run ``generate`` on the encoded prompt.

        Returns:
            ``{"logits": ids}``. Despite the key name, ``ids`` are the token
            ids returned by ``generate``, not logits. The key is kept because
            ``postprocess`` and existing callers read it.
        """
        generate_kwargs = {
            key: model_inputs[key]
            for key in self._GENERATE_KEYS
            if key in model_inputs
        }
        tokenizer = self.tokenizer
        res_ids = self.model.generate(
            model_inputs["inputs"],
            eos_token_id=tokenizer.eos_token_id,
            # Forbid the unknown token in the generated question.
            bad_words_ids=[[tokenizer.unk_token_id]],
            **generate_kwargs,
        )
        return {"logits": res_ids}

    def postprocess(self, model_outputs):
        """Decode the first generated sequence and strip ``<s>``/``</s>`` markers."""
        decoded = self.tokenizer.batch_decode(model_outputs["logits"].tolist())[0]
        return decoded.replace('<s>', '').replace('</s>', '')

    def _inference(self, paragraph, **kwargs):
        """Run preprocess -> forward -> postprocess for a single paragraph.

        Uses ``Pipeline.forward`` rather than ``_forward`` directly. This
        moves the encoded inputs to the pipeline's device and runs generation
        inside the pipeline's inference context, which a direct call skips.
        """
        model_inputs = self.preprocess(paragraph, **kwargs)
        model_outputs = self.forward(model_inputs)
        return self.postprocess(model_outputs)

    def make_question(self, text, **kwargs):
        """Generate one question per sliding word window of ``text``.

        Args:
            text: source text, split on whitespace into words.
            **kwargs: must include ``frame_size`` (words per window) and
                ``hop_length`` (words between consecutive window starts),
                both positive ints. All kwargs are also passed to
                ``_inference``, so ``max_length`` / ``num_beams`` apply.

        Returns:
            List of generated question strings, one per window. Windows start
            every ``hop_length`` words, and the last window is the first one
            that reaches the end of the text. When ``hop_length <=
            frame_size``, every word is covered. Returns ``[]`` if ``text``
            contains no words.

        Raises:
            KeyError: if ``frame_size`` or ``hop_length`` is missing.
            ValueError: if ``frame_size`` or ``hop_length`` is not positive.
        """
        words = text.split()
        frame_size = kwargs['frame_size']
        hop_length = kwargs['hop_length']
        if frame_size <= 0 or hop_length <= 0:
            raise ValueError("frame_size and hop_length must be positive")

        outs = []
        for start in range(0, len(words), hop_length):
            script = " ".join(words[start:start + frame_size])
            outs.append(self._inference(script, **kwargs))
            # Stop once this window reaches the end of the text. The tail is
            # then covered exactly once and no empty window is generated.
            if start + frame_size >= len(words):
                break
        return outs