| import gradio as gr |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| import torch |
|
|
| |
# Checkpoint identifier handed to ``from_pretrained`` (a Hugging Face Hub repo
# id or a local path). The tokenizer is loaded first, then the classifier, both
# at import time so every request reuses the same objects.
model_name = "fohake/cert"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
|
|
| |
def predict(text):
    """Classify ``text`` with the module-level sequence-classification model.

    Args:
        text: A single input string to classify.

    Returns:
        A dict with ``"class"`` (the integer index of the highest-probability
        label) and ``"confidence"`` (that label's softmax probability, a float).
    """
    # Truncate to the tokenizer's ``model_max_length``. Without this, inputs
    # longer than the model supports typically make the forward pass raise
    # (e.g. running past the position embeddings) instead of returning a
    # prediction.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; skip building the autograd graph
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # A single string was tokenized, so the batch has exactly one row (index 0).
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_class].item()
    return {"class": predicted_class, "confidence": confidence}
|
|
| |
# Web UI: a two-line textbox feeds ``predict``, and the returned dict is shown
# as JSON. ``gr.Textbox`` is used because the ``gr.inputs`` namespace was
# deprecated in Gradio 3.x and removed in Gradio 4.0.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
    outputs="json",
    title="Text Classification with CERT",
    description="Enter a piece of text to classify it using the CERT model.",
)
|
|
def _main():
    """Start the Gradio app built above with ``launch()``'s default settings."""
    iface.launch()


if __name__ == "__main__":
    _main()
|
|