| import gradio as gr |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import torch |
|
|
|
|
| |
# Hugging Face Hub id of the ByT5-small checkpoint fine-tuned for
# text normalization (byte-level model: no vocabulary preprocessing needed).
MODEL_NAME = "comma-project/normalization-byt5-small"


# Load the tokenizer and seq2seq model once at import time so every
# request reuses the same weights (downloaded/cached on first run).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
|
|
|
def normalize_text(text: str) -> str:
    """Run the ByT5 normalization model over *text* and return the result.

    Whitespace-only input short-circuits to an empty string so the model
    is never invoked on nothing.
    """
    if not text.strip():
        return ""

    # Encode the raw input; long texts are truncated to 1024 tokens.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=1024,
    )

    # Beam search (2 beams) under no_grad: inference only, no autograd graph.
    with torch.no_grad():
        generated_ids = model.generate(
            **encoded,
            max_length=1024,
            num_beams=2,
            early_stopping=True,
        )

    # Decode the first (and only) generated sequence back into plain text.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
|
| |
demo = gr.Interface(
    fn=normalize_text,
    inputs=gr.Textbox(
        label="Input Text",
        placeholder="Enter text to normalize...",
        lines=4,
    ),
    outputs=gr.Textbox(
        label="Normalized Text",
        lines=4,
    ),
    title="Text Normalization with ByT5",
    description="Normalize noisy or non-standard text using the ByT5 model.",
    theme="soft",
    # Each example row must contain exactly one value per input component.
    # The original second row bundled two texts into a single example, which
    # is malformed for a single-input interface, so the two texts are listed
    # as separate example rows instead.
    examples=[
        ["Scͥbo uobiᷤᷤ ñ pauli ł donati."],
        ["""⁊ pitie mlt' lelasce
P ities li dist. uai a ton peire
Nelaissier. """],
        ["""Uer̃ ab his qͥ ita dissert̃
q̃ri debet. qͥd ꝑ amorem dei. quidq ꝑ amorẽ
boni tẽꝑalis ueluit intellig̾e."""],
    ],
)
|
|
if __name__ == "__main__":
    # Start the Gradio web server only when run as a script (not on import).
    demo.launch()
|
|