| import gradio as gr |
| import torch |
| from transformers import PegasusForConditionalGeneration, PegasusTokenizer |
| import re |
| import os |
|
|
def load_model():
    """Load the Pegasus tokenizer and model from the local ./models directory.

    Returns:
        A ``(tokenizer, model, device)`` triple, where ``device`` is
        ``'cuda'`` when a GPU is available and ``'cpu'`` otherwise; the
        model is already moved onto that device.
    """
    # Prefer the GPU when torch can see one.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    pegasus_tokenizer = PegasusTokenizer.from_pretrained('./models')
    pegasus_model = PegasusForConditionalGeneration.from_pretrained('./models')
    return pegasus_tokenizer, pegasus_model.to(device), device
|
|
def split_into_paragraphs(text):
    """Break *text* on blank lines and return the non-empty, trimmed paragraphs."""
    result = []
    for chunk in text.split('\n\n'):
        trimmed = chunk.strip()
        if trimmed:
            result.append(trimmed)
    return result
|
|
def split_into_sentences(paragraph):
    """Split *paragraph* at whitespace that follows ``.``, ``!`` or ``?``.

    Returns the trimmed, non-empty sentence strings.
    """
    pieces = re.split(r'(?<=[.!?])\s+', paragraph)
    return [piece.strip() for piece in pieces if piece.strip()]
|
|
def get_response(input_text, num_return_sequences, tokenizer, model, torch_device):
    """Generate paraphrase(s) for one sentence with beam search.

    Args:
        input_text: Sentence to paraphrase.
        num_return_sequences: Number of candidates to generate
            (must be <= ``num_beams``); only the first is returned.
        tokenizer: Pegasus tokenizer.
        model: Pegasus seq2seq model, assumed to already live on *torch_device*.
        torch_device: Target device string/device for the encoded batch.

    Returns:
        The first decoded paraphrase string.
    """
    # prepare_seq2seq_batch() was deprecated and removed from transformers;
    # calling the tokenizer directly yields the same encoded batch.
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=80,
        return_tensors="pt"
    ).to(torch_device)

    # Inference only — no_grad skips autograd bookkeeping.
    # Note: temperature is omitted; the original passed temperature=1.0
    # (the default), which is a no-op without do_sample=True and triggers
    # a warning in recent transformers versions.
    with torch.no_grad():
        translated = model.generate(
            **batch,
            num_beams=10,
            num_return_sequences=num_return_sequences,
            repetition_penalty=2.8,
            length_penalty=1.2,
            max_length=80,
            min_length=5,
            no_repeat_ngram_size=3
        )

    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return tgt_text[0]
|
|
def get_response_from_text(context, tokenizer, model, torch_device):
    """Paraphrase *context* sentence by sentence while keeping paragraph layout.

    Sentences shorter than three words are passed through unchanged, as are
    sentences whose paraphrase contains filler phrasing or whose generation
    raised an error.
    """
    output_paragraphs = []

    for para in split_into_paragraphs(context):
        rewritten = []

        for sent in split_into_sentences(para):
            # Too short to paraphrase meaningfully — keep as-is.
            if len(sent.split()) < 3:
                rewritten.append(sent)
                continue

            try:
                candidate = get_response(sent, 1, tokenizer, model, torch_device)
            except Exception as e:
                # Best-effort: on failure fall back to the original sentence.
                print(f"Error processing sentence: {e}")
                rewritten.append(sent)
                continue

            # Reject paraphrases that introduce filler phrasing.
            if any(phrase in candidate.lower() for phrase in ['it\'s like', 'in other words']):
                rewritten.append(sent)
            else:
                rewritten.append(candidate)

        output_paragraphs.append(' '.join(rewritten))

    return '\n\n'.join(output_paragraphs)
|
|
def create_interface():
    """Build and return the Gradio interface wired to the paraphrasing pipeline."""
    tokenizer, model, torch_device = load_model()

    def greet(context):
        # Thin closure binding the loaded model into the UI callback.
        return get_response_from_text(context, tokenizer, model, torch_device)

    input_box = gr.Textbox(
        lines=15,
        label="Input Text",
        placeholder="Enter your text here...",
        elem_classes="input-text"
    )
    output_box = gr.Textbox(
        lines=15,
        label="Paraphrased Text",
        elem_classes="output-text"
    )
    custom_css = """
        .input-text, .output-text {
            font-size: 16px !important;
            font-family: Arial, sans-serif !important;
            min-height: 300px !important;
        }
        """

    return gr.Interface(
        fn=greet,
        inputs=input_box,
        outputs=output_box,
        title="Advanced Text Paraphraser",
        description="Enter text to generate a high-quality paraphrased version while maintaining paragraph structure.",
        theme="default",
        css=custom_css
    )
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    create_interface().launch()