# Commit 3270eb3 by Rishi2455: "Update app.py" (verified)
import gradio as gr
from gliner import GLiNER
import torch
# Pick the GPU when CUDA is available, otherwise fall back to the CPU, then
# load the GLiNER checkpoint once at import time and move it to that device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = GLiNER.from_pretrained("Rishi2455/gliner-floatbot-ai")
model = model.to(device)
def extract_entities(text, labels, threshold):
    """Extract named entities from text using GLiNER.

    Args:
        text: Input text to analyze.
        labels: Comma-separated entity labels, e.g. "person, organization".
            Each label is stripped and lower-cased. Empty entries are
            dropped, and duplicates are removed while keeping first-seen
            order.
        threshold: Minimum confidence score, forwarded unchanged to
            ``model.predict_entities``.

    Returns:
        ``{"entities": [...]}``. Each item holds ``text``, ``label`` and
        ``score`` (a float rounded to 4 decimals). The list is empty when
        the text is missing/blank or no usable label remains after parsing.
    """
    # Missing or whitespace-only input: skip the model call entirely.
    if not text or not text.strip() or not labels:
        return {"entities": []}

    # Normalize labels. dict.fromkeys de-duplicates while preserving input
    # order, so "person, Person" does not send the same label twice.
    label_list = list(
        dict.fromkeys(
            label.strip().lower() for label in labels.split(",") if label.strip()
        )
    )
    # A string like " , ," is truthy but yields no real label; do not call
    # the model with an empty label set.
    if not label_list:
        return {"entities": []}

    entities = model.predict_entities(text, label_list, threshold=threshold)

    # float() ensures a plain JSON-serializable number even if the score
    # arrives as a numpy/torch scalar.
    return {
        "entities": [
            {
                "text": entity["text"],
                "label": entity["label"],
                "score": round(float(entity["score"]), 4),
            }
            for entity in entities
        ]
    }
def _build_demo():
    """Assemble the Gradio UI around ``extract_entities``.

    The input components are passed positionally to ``extract_entities`` in
    this order: text, comma-separated labels, confidence threshold. The
    result is shown as JSON.
    """
    input_components = [
        gr.Textbox(label="Text", placeholder="Enter text to analyze...", lines=5),
        gr.Textbox(label="Labels (comma-separated)", placeholder="person, organization, location"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold"),
    ]
    # Each row matches the order of input_components above.
    sample_rows = [
        ["Elon Musk founded SpaceX in Hawthorne, California.", "person, organization, location", 0.4],
        ["Apple released the iPhone 15 in September 2023.", "company, product, date", 0.4],
    ]
    return gr.Interface(
        fn=extract_entities,
        inputs=input_components,
        outputs=gr.JSON(label="Extracted Entities"),
        title="GLiNER Named Entity Recognition",
        description="Extract named entities from text using GLiNER. Enter your text, specify entity labels (comma-separated), and adjust the confidence threshold.",
        examples=sample_rows,
    )


demo = _build_demo()
def main():
    """Launch the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    main()