| import torch |
| import gradio as gr |
| from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig |
|
|
| |
# Hub model IDs offered in the UI dropdown. All are loaded via the generic
# "summarization" pipeline; max input length varies per model and is read
# from each model's config at load time (see load_model).
model_names = [
    "google/bigbird-pegasus-large-arxiv",
    "facebook/bart-large-cnn",
    "google/t5-v1_1-large",
    "sshleifer/distilbart-cnn-12-6",
    "allenai/led-base-16384",
    "google/pegasus-xsum",
    "togethercomputer/LLaMA-2-7B-32K"
]
|
|
| |
# Module-level state shared between load_model() and summarize_text().
# All three stay None until a model is successfully loaded.
summarizer = None   # transformers summarization pipeline for the active model
tokenizer = None    # tokenizer used to count input tokens before summarizing
max_tokens = None   # int: the active model's maximum input length in tokens
|
|
| |
# Default text pre-filled in the input textbox so users can try the app
# without pasting their own content.
example_text = (
    "Artificial intelligence (AI) is intelligence—perceiving, synthesizing, and inferring information—"
    "demonstrated by machines, as opposed to intelligence displayed by non-human animals and humans. "
    "Example tasks in which AI is employed include speech recognition, computer vision, language translation, "
    "autonomous vehicles, and game playing. AI research has been defined as the field of study of intelligent "
    "agents, which refers to any system that perceives its environment and takes actions that maximize its "
    "chance of achieving its goals."
)
|
|
| |
def load_model(model_name):
    """Load a summarization pipeline for *model_name* into module globals.

    Sets the module-level ``summarizer``, ``tokenizer`` and ``max_tokens``
    used by summarize_text().

    Args:
        model_name: HuggingFace Hub model ID (one of ``model_names``).

    Returns:
        A human-readable status string (success with the token limit, or
        the error message on failure) shown in the UI status textbox.
    """
    global summarizer, tokenizer, max_tokens
    try:
        summarizer = pipeline("summarization", model=model_name, torch_dtype=torch.float32)
        # The pipeline already constructed a tokenizer and config for this
        # model; reuse them instead of downloading/instantiating them a
        # second time with AutoTokenizer/AutoConfig (the original code did,
        # doubling load time and memory for no behavioral difference).
        tokenizer = summarizer.tokenizer
        config = summarizer.model.config

        # Not every architecture exposes max_position_embeddings (e.g. some
        # relative-position models); fall back to a conservative 1024.
        max_tokens = getattr(config, 'max_position_embeddings', 1024)

        return f"Model {model_name} loaded successfully! Max tokens: {max_tokens}"
    except Exception as e:
        # Broad catch is deliberate: any load failure (network, missing
        # weights, unsupported arch) is surfaced to the UI, not raised.
        return f"Failed to load model {model_name}. Error: {str(e)}"
|
|
| |
def summarize_text(input, min_length, max_length):
    """Summarize *input* with the currently loaded pipeline.

    Args:
        input: Text to summarize. (Name shadows the builtin but is kept
            for interface compatibility with the Gradio binding.)
        min_length: Minimum summary length as a percentage (0-100) of the
            input's token count.
        max_length: Maximum summary length as a percentage (0-100) of the
            input's token count.

    Returns:
        The summary string, or a human-readable error message if no model
        is loaded, the input is too long, or generation fails.
    """
    if summarizer is None:
        return "No model loaded!"

    try:
        # Count tokens with the model's own tokenizer to enforce its limit.
        input_tokens = tokenizer.encode(input, return_tensors="pt")
        num_tokens = input_tokens.shape[1]
        if num_tokens > max_tokens:
            return f"Error: Input exceeds the max token limit of {max_tokens}."

        # Convert percentage sliders to absolute token counts, with a floor
        # of 10 tokens so tiny inputs still produce a usable summary.
        min_summary_length = max(10, int(num_tokens * (min_length / 100)))
        max_summary_length = min(max_tokens, int(num_tokens * (max_length / 100)))
        # Bug fix: if the max slider is below the min slider, or the input is
        # short enough that the 10-token floor exceeds the computed max, the
        # pipeline would be called with max_length <= min_length and fail.
        # Clamp so max is always strictly greater than min.
        max_summary_length = max(max_summary_length, min_summary_length + 1)

        output = summarizer(input, min_length=min_summary_length, max_length=max_summary_length, truncation=True)
        return output[0]['summary_text']
    except Exception as e:
        # Surface generation errors in the UI rather than crashing the app.
        return f"Summarization failed: {str(e)}"
|
|
| |
# Build the Gradio UI. Widget creation order defines on-page layout, so the
# sequence of statements below is significant and mirrors the rendered page.
with gr.Blocks() as demo:
    # Top row: model picker plus an explicit load trigger (loading is slow,
    # so it is not done automatically on dropdown change).
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=model_names,
            label="Choose a model",
            value="sshleifer/distilbart-cnn-12-6",
        )
        load_btn = gr.Button("Load Model")

    # Read-only status line reporting load success/failure.
    status_box = gr.Textbox(label="Load Status", interactive=False)

    # Summary length bounds, expressed as a percentage of the input length.
    min_pct = gr.Slider(minimum=0, maximum=100, step=1, label="Minimum Summary Length (%)", value=10)
    max_pct = gr.Slider(minimum=0, maximum=100, step=1, label="Maximum Summary Length (%)", value=20)

    # Input / action / output for the summarization itself.
    source_box = gr.Textbox(label="Input text to summarize", lines=6, value=example_text)
    run_btn = gr.Button("Summarize Text")
    result_box = gr.Textbox(label="Summarized text", lines=4)

    # Wire the buttons to the backend functions.
    load_btn.click(fn=load_model, inputs=model_selector, outputs=status_box)
    run_btn.click(
        fn=summarize_text,
        inputs=[source_box, min_pct, max_pct],
        outputs=result_box,
    )

# Start the local web server (blocks until shut down).
demo.launch()
|
|