| from typing import Dict, List, Any |
| import torch |
| from transformers import PegasusForConditionalGeneration, PegasusTokenizer |
| import re |
|
|
class EndpointHandler:
    """Inference endpoint handler that paraphrases text with a Pegasus model.

    The incoming text is split into paragraphs (on blank lines) and then
    into sentences; each sentence is paraphrased independently and the
    results are re-joined so the paragraph structure of the input is kept.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the endpoint handler with the model and tokenizer.

        :param path: Path (or hub id) of the pretrained Pegasus weights
        """
        # Prefer GPU when available; the model and every input batch are
        # moved to this device.
        self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines.

        Note: empty paragraphs are dropped, so runs of three or more
        consecutive newlines collapse to a single paragraph break when the
        caller re-joins with '\\n\\n'.

        :param text: Input text
        :return: List of non-empty, stripped paragraphs
        """
        paragraphs = text.split('\n\n')
        return [p.strip() for p in paragraphs if p.strip()]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split a paragraph into sentences using a regex that breaks after
        terminal punctuation (., !, ?) followed by whitespace.

        :param paragraph: Input paragraph
        :return: List of non-empty, stripped sentences
        """
        sentences = re.split(r'(?<=[.!?])\s+', paragraph)
        return [s.strip() for s in sentences if s.strip()]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate paraphrased text for a single input sentence.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of alternative paraphrases to generate
        :return: The first generated paraphrase
        """
        # `prepare_seq2seq_batch` was deprecated in transformers 4.x and
        # later removed; calling the tokenizer directly is the supported,
        # equivalent API and yields the same BatchEncoding.
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            max_length=80,
            return_tensors="pt",
        ).to(self.torch_device)

        # Beam search (num_beams=10) is deterministic; `temperature` is
        # omitted because it only takes effect when do_sample=True.
        translated = self.model.generate(
            **batch,
            num_beams=10,
            num_return_sequences=num_return_sequences,
            repetition_penalty=2.8,
            length_penalty=1.2,
            max_length=80,
            min_length=5,
            no_repeat_ngram_size=3,
        )

        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return tgt_text[0]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an incoming request and return the paraphrased text.

        :param data: Request payload; the text is read from data["inputs"]
                     (falling back to `data` itself, which must then be a str)
        :return: {"outputs": paraphrased text, paragraphs joined by blank lines}
        :raises ValueError: if the resolved input is not a string
        """
        inputs = data.pop("inputs", data)

        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")

        paragraphs = self.split_into_paragraphs(inputs)
        paraphrased_paragraphs = []

        for paragraph in paragraphs:
            sentences = self.split_into_sentences(paragraph)
            paraphrased_sentences = []

            for sentence in sentences:
                # Very short fragments (< 3 words) are passed through
                # unchanged; paraphrasing them is rarely meaningful.
                if len(sentence.split()) < 3:
                    paraphrased_sentences.append(sentence)
                    continue

                try:
                    paraphrased = self.get_response(sentence)

                    # Reject paraphrases that contain filler framing phrases
                    # and keep the original sentence instead.
                    if not any(phrase in paraphrased.lower() for phrase in ['it\'s like', 'in other words']):
                        paraphrased_sentences.append(paraphrased)
                    else:
                        paraphrased_sentences.append(sentence)
                except Exception as e:
                    # Best-effort: fall back to the original sentence rather
                    # than failing the whole request on one bad generation.
                    print(f"Error processing sentence: {e}")
                    paraphrased_sentences.append(sentence)

            paraphrased_paragraphs.append(' '.join(paraphrased_sentences))

        return {"outputs": '\n\n'.join(paraphrased_paragraphs)}