"""Gradio demo serving a GLiNER model for label-driven named entity recognition."""

import gradio as gr
import torch
from gliner import GLiNER

# Load the model once at import time; use the GPU when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GLiNER.from_pretrained("Rishi2455/gliner-floatbot-ai").to(device)


def extract_entities(text, labels, threshold):
    """Extract named entities from text using GLiNER.

    Args:
        text: The text to analyze.
        labels: Comma-separated entity labels, e.g. "person, location".
            Each label is stripped and lower-cased; empty items and
            duplicates are dropped (first occurrence order is kept).
        threshold: Confidence threshold forwarded to
            ``model.predict_entities``.

    Returns:
        ``{"entities": [...]}`` where each item is a dict with ``"text"``,
        ``"label"`` and ``"score"`` (rounded to 4 decimals). The list is empty
        when the text is blank or no usable label remains after parsing.
    """
    # Treat whitespace-only text like empty text instead of running the model.
    if not text or not text.strip() or not labels:
        return {"entities": []}

    # Split comma-separated labels; dict.fromkeys de-duplicates while keeping
    # the user's ordering (e.g. "person, Person" -> ["person"]).
    label_list = list(
        dict.fromkeys(
            label.strip().lower() for label in labels.split(",") if label.strip()
        )
    )
    # Input such as " , ," yields no labels; skip the model call rather than
    # asking it to predict against an empty label set.
    if not label_list:
        return {"entities": []}

    entities = model.predict_entities(text, label_list, threshold=threshold)

    return {
        "entities": [
            {
                "text": entity["text"],
                "label": entity["label"],
                # float() keeps the value a native Python float in case the
                # model hands back a numpy/torch scalar, so gr.JSON can
                # serialize it.
                "score": round(float(entity["score"]), 4),
            }
            for entity in entities
        ]
    }


# Build Gradio interface
demo = gr.Interface(
    fn=extract_entities,
    inputs=[
        gr.Textbox(label="Text", placeholder="Enter text to analyze...", lines=5),
        gr.Textbox(label="Labels (comma-separated)", placeholder="person, organization, location"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.JSON(label="Extracted Entities"),
    title="GLiNER Named Entity Recognition",
    description="Extract named entities from text using GLiNER. Enter your text, specify entity labels (comma-separated), and adjust the confidence threshold.",
    examples=[
        ["Elon Musk founded SpaceX in Hawthorne, California.", "person, organization, location", 0.4],
        ["Apple released the iPhone 15 in September 2023.", "company, product, date", 0.4],
    ],
)

if __name__ == "__main__":
    demo.launch()