| |
| """Untitled0.ipynb |
| |
| Automatically generated by Colaboratory. |
| |
| Original file is located at |
| https://colab.research.google.com/drive/1aMkctyYgdHD61sv7-bJHFN1B5taCv6c2 |
| """ |
|
|
| import gradio as gr |
| from datasets import load_dataset |
| import evaluate |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer |
| import numpy as np |
| import nltk |
|
|
# Sentence-splitter data for nltk.sent_tokenize (used when scoring ROUGE).
# Recent NLTK releases load the "punkt_tab" resource instead of the pickled
# "punkt" model, so fetch both to work across NLTK versions. nltk.download
# reports (rather than raises) an unknown id by default.
nltk.download("punkt")
nltk.download("punkt_tab")
raw_dataset = load_dataset("scientific_papers", "pubmed")
metric = evaluate.load("rouge")
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
|
# The original T5 checkpoints are driven by a task prefix, so their inputs
# are prepended with "summarize: "; any other checkpoint gets raw text.
prefix = (
    "summarize: "
    if model_checkpoint in ("t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b")
    else ""
)
|
|
| |
# Truncation limits, in tokens, for the encoder input (article) and the
# target sequence (abstract).
max_input_length = 512
max_target_length = 128


def preprocess_function(examples):
    """Tokenize a batch of papers for seq2seq fine-tuning.

    Args:
        examples: a batched slice of the dataset; only its ``"article"`` and
            ``"abstract"`` columns are read.

    Returns:
        The tokenizer encoding of the prefixed articles (truncated to
        ``max_input_length``), with an added ``"labels"`` entry holding the
        abstract token ids (truncated to ``max_target_length``).
    """
    prefixed_articles = [prefix + article for article in examples["article"]]
    encoded = tokenizer(prefixed_articles, max_length=max_input_length, truncation=True)

    target_ids = tokenizer(
        text_target=examples["abstract"],
        max_length=max_target_length,
        truncation=True,
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded
|
|
# Shrink every split to a small random sample so the Colab run stays short.
# A seeded Generator makes the subset reproducible. Sampling is done without
# replacement over the full index range [0, n): the previous
# randint(0, n - 1, k) could never pick the last row and could repeat rows.
_SAMPLES_PER_SPLIT = 1_000
_subsample_rng = np.random.default_rng(42)
for split in ["train", "validation", "test"]:
    n_rows = len(raw_dataset[split])
    sample_size = min(_SAMPLES_PER_SPLIT, n_rows)  # guard splits smaller than the sample
    indices = _subsample_rng.choice(n_rows, size=sample_size, replace=False)
    raw_dataset[split] = raw_dataset[split].select(indices.tolist())
tokenized_dataset = raw_dataset.map(preprocess_function, batched=True)
|
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Per-device batch size, shared by training and evaluation.
batch_size = 8

# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version before upgrading.
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_checkpoint}-scientific_papers",
    # Optimisation schedule.
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,
    # Evaluation / checkpointing.
    evaluation_strategy="epoch",
    save_total_limit=3,
    # Evaluate on generated sequences so compute_metrics can decode text.
    predict_with_generate=True,
    push_to_hub=False,
)
|
|
# Batch collator for seq2seq data; the model is handed over as well
# (presumably so decoder inputs can be derived from labels — see the
# DataCollatorForSeq2Seq docs).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
|
|
| |
def compute_metrics(eval_pred):
    """Score generated summaries against reference abstracts with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` as supplied by the Trainer. With
            ``predict_with_generate=True`` the predictions are generated token
            ids; ``predictions`` may also arrive as a tuple whose first element
            holds the ids.

    Returns:
        dict mapping each ROUGE key (scaled to 0-100) plus ``"gen_len"`` (mean
        number of non-pad tokens per prediction) to a value rounded to 4 places.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Both arrays can contain -100: labels use it as the ignore index, and the
    # Trainer pads predictions of different lengths with it when concatenating
    # eval batches. Negative ids cannot be decoded, and they would also be
    # miscounted as real tokens in gen_len, so map them back to the pad id.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Put one sentence per line (this is the format rougeLsum scores on).
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {key: value * 100 for key, value in result.items()}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}
|
|
# Wire the model, data and metric together, then fine-tune.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
|
|
| |
| import gradio as gr |
|
|
def summarizer(input_text):
    """Generate a summary of ``input_text`` with the fine-tuned model.

    Args:
        input_text: the document to summarize. It is prepended with the task
            prefix and truncated to ``max_input_length`` tokens.

    Returns:
        The decoded summary string, with special tokens removed.
    """
    model_inputs = tokenizer(
        [prefix + input_text],
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # After trainer.train() the model may live on a GPU, while the tokenizer
    # returns CPU tensors; move the inputs to the model's device.
    model_inputs = model_inputs.to(model.device)

    summary_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        max_length=max_target_length + 2,
        repetition_penalty=2.0,
        early_stopping=True,
        use_cache=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
| |
# Web UI: a text box in, the generated summary out.
iface = gr.Interface(
    fn=summarizer,
    # gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in
    # Gradio 4; gr.Textbox serves as both an input and an output component.
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Textbox(label="Summary"),
    title="Scientific Paper Summarizer",
    description="Summarizes scientific papers using a fine-tuned T5 model",
    # NOTE(review): "gray" does not look like a built-in Gradio theme name;
    # confirm which theme is intended for the installed Gradio version.
    theme="gray",
)
iface.launch()