mou11's picture
Create app.py
1d8427d verified
import os
import json
import requests
import gradio as gr
from typing import TypedDict, Annotated, List
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
def create_agent(api_key: str):
    """Build the clinical-decision-support graph and its Groq chat model.

    The compiled LangGraph pipeline runs four nodes in order:

        extract_query -> search_pubmed -> generate_recommendation -> format_response

    1. ``extract_query`` asks the LLM to condense ``patient_input`` into a
       short PubMed query.
    2. ``search_pubmed`` queries NCBI E-utilities (esearch, then efetch) and
       parses up to five article records.
    3. ``generate_recommendation`` asks the LLM for a JSON recommendation
       based on those abstracts and parses it into a dict.
    4. ``format_response`` appends the turn to ``conversation_history``.

    Args:
        api_key: Groq API key passed to ``ChatGroq``.

    Returns:
        ``(compiled_graph, llm)``. ``compiled_graph.invoke(state)`` expects a
        dict containing every ``AgentState`` key and returns the final state.
    """
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.1,
        api_key=api_key,
    )

    eutils_base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Seconds to wait on each E-utilities call. requests applies no timeout by
    # default, so a stalled connection could otherwise block the request forever.
    http_timeout = 15
    # Abstracts are truncated to keep the recommendation prompt compact.
    abstract_limit = 500

    class AgentState(TypedDict):
        messages: Annotated[list, add_messages]  # not read or written by any node below
        patient_input: str  # free-text symptom description
        search_query: str  # PubMed query produced by extract_query
        pubmed_results: List[dict]  # article records produced by search_pubmed
        clinical_recommendation: dict  # parsed LLM recommendation
        conversation_history: List[dict]  # earlier turns plus the current one

    def placeholder_article(title: str, abstract: str) -> dict:
        """Return a stand-in record with the same keys as a parsed article."""
        return {"title": title, "abstract": abstract, "authors": "", "year": "", "pmid": "", "url": ""}

    def tag_text(parent, name: str) -> str:
        """Return the stripped text of the first ``name`` descendant of ``parent``, or ''."""
        tag = parent.find(name) if parent is not None else None
        return tag.get_text().strip() if tag is not None else ""

    def parse_article(article):
        """Convert one ``<PubmedArticle>`` element into an article record.

        Returns None when the element is malformed so the caller can skip it.
        """
        try:
            title = tag_text(article, "ArticleTitle") or "No title"

            # Structured abstracts split the text over several <AbstractText>
            # sections (BACKGROUND, METHODS, ...); keep all of them, not just the first.
            sections = (section.get_text().strip() for section in article.find_all("AbstractText"))
            abstract = " ".join(s for s in sections if s) or "No abstract available"
            if len(abstract) > abstract_limit:
                abstract = abstract[:abstract_limit] + "..."

            pub_date = article.find("PubDate")
            year = tag_text(pub_date, "Year")
            if not year:
                # Some records carry only <MedlineDate>, e.g. "1998 Dec-1999 Jan".
                medline_date = tag_text(pub_date, "MedlineDate")
                year = medline_date[:4] if medline_date[:4].isdigit() else "Unknown"

            # Count only <Author> entries so other name-bearing elements
            # (e.g. <Investigator> in the PubMed DTD) are not listed as authors.
            last_names = [
                tag.get_text().strip()
                for tag in article.find_all("LastName")
                if tag.parent is not None and tag.parent.name == "Author"
            ]
            if not last_names:
                authors = "Unknown"
            elif len(last_names) > 3:
                authors = ", ".join(last_names[:3]) + " et al."
            else:
                authors = ", ".join(last_names)

            pmid = tag_text(article, "PMID")
        except (AttributeError, TypeError):
            return None

        return {
            "title": title,
            "abstract": abstract,
            "authors": authors,
            "year": year,
            "pmid": pmid,
            # Avoid emitting a link with an empty PMID path segment.
            "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else "",
        }

    def search_pubmed(query: str, max_results: int = 5) -> List[dict]:
        """Search PubMed and return parsed article records.

        Always returns a non-empty list. A single placeholder record is
        returned when the request fails, nothing matches, or no article could
        be parsed, so the recommendation step still receives context.
        """
        search_params = {"db": "pubmed", "term": query, "retmax": max_results, "retmode": "json", "sort": "relevance"}
        try:
            search_response = requests.get(f"{eutils_base}/esearch.fcgi", params=search_params, timeout=http_timeout)
            search_response.raise_for_status()
            article_ids = search_response.json()["esearchresult"]["idlist"]
        except (requests.RequestException, ValueError, KeyError, TypeError) as exc:
            return [placeholder_article(
                "PubMed search failed",
                f"PubMed could not be queried ({type(exc).__name__}); no literature was retrieved.",
            )]

        if not article_ids:
            return [placeholder_article("No results found", "No relevant articles found.")]

        fetch_params = {"db": "pubmed", "id": ",".join(article_ids), "retmode": "xml", "rettype": "abstract"}
        try:
            fetch_response = requests.get(f"{eutils_base}/efetch.fcgi", params=fetch_params, timeout=http_timeout)
            fetch_response.raise_for_status()
        except requests.RequestException as exc:
            return [placeholder_article(
                "PubMed fetch failed",
                f"Article details could not be downloaded ({type(exc).__name__}).",
            )]

        # Imported here so bs4 is only required once a search actually runs.
        # BeautifulSoup's "xml" parser is backed by lxml.
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(fetch_response.text, "xml")
        parsed = (parse_article(article) for article in soup.find_all("PubmedArticle"))
        articles = [record for record in parsed if record is not None]
        return articles or [placeholder_article("Parse error", "Could not parse results.")]

    def parse_recommendation(raw) -> dict:
        """Extract the JSON recommendation object from an LLM reply.

        Tolerates Markdown code fences and prose around the object. When no
        JSON object can be decoded, returns a placeholder dict whose
        ``reasoning`` holds the raw reply.
        """
        text = str(raw).strip()
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```")[1]
        text = text.strip()
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        # Valid JSON that is not an object (list, string, number) would break
        # callers that call .get() on the recommendation.
        if isinstance(parsed, dict):
            return parsed
        return {
            "possible_conditions": ["Unable to parse"],
            "recommended_tests": [],
            "treatment_considerations": [],
            "urgency_level": "Unknown",
            "reasoning": raw,
            "important_disclaimer": "Always consult a qualified healthcare professional.",
            "sources": [],
        }

    # Graph nodes return only the keys they update; LangGraph merges them
    # into the running state.

    def extract_search_query(state: AgentState) -> dict:
        """Ask the LLM to turn the patient description into a short PubMed query."""
        prompt = (
            "You are a medical AI assistant. Generate the best PubMed search query for these symptoms.\n"
            f"Patient input: {state['patient_input']}\n"
            "Return ONLY the search query (max 8 words), nothing else."
        )
        response = llm.invoke([HumanMessage(content=prompt)])
        query = response.content.strip()
        # An empty query would make esearch return nothing; fall back to the raw input.
        return {"search_query": query or state["patient_input"]}

    def search_medical_literature(state: AgentState) -> dict:
        """Fetch up to five PubMed articles for the generated query."""
        return {"pubmed_results": search_pubmed(state["search_query"], max_results=5)}

    def generate_recommendation(state: AgentState) -> dict:
        """Ask the LLM for a structured JSON recommendation grounded in the articles."""
        articles_text = "".join(
            f"\nArticle {number}:\nTitle: {article['title']}\n"
            f"Authors: {article['authors']} ({article['year']})\n"
            f"Abstract: {article['abstract']}\nURL: {article['url']}\n"
            for number, article in enumerate(state["pubmed_results"], start=1)
        )
        prompt = (
            "You are an expert clinical decision support AI.\n"
            f"Patient symptoms: {state['patient_input']}\n"
            f"Relevant Medical Literature: {articles_text}\n"
            "Generate a structured response in this EXACT JSON format:\n"
            "{\n"
            '"possible_conditions": ["condition1", "condition2", "condition3"],\n'
            '"recommended_tests": ["test1", "test2", "test3"],\n'
            '"treatment_considerations": ["consideration1", "consideration2"],\n'
            '"urgency_level": "Low/Medium/High/Emergency",\n'
            '"reasoning": "Brief explanation",\n'
            '"important_disclaimer": "This is AI-generated information for educational purposes only. '
            'Always consult a qualified healthcare professional.",\n'
            '"sources": ["Article title - Authors (Year)"]\n'
            "}\n"
            "Return ONLY the JSON object."
        )
        response = llm.invoke([HumanMessage(content=prompt)])
        return {"clinical_recommendation": parse_recommendation(response.content)}

    def format_response(state: AgentState) -> dict:
        """Append this turn to a copy of the conversation history."""
        entry = {"patient_input": state["patient_input"], "recommendation": state["clinical_recommendation"]}
        # Build a new list instead of appending in place: the incoming list is
        # the object the caller passed into the graph.
        history = [*(state.get("conversation_history") or []), entry]
        return {"conversation_history": history}

    graph = StateGraph(AgentState)
    graph.add_node("extract_query", extract_search_query)
    graph.add_node("search_pubmed", search_medical_literature)
    graph.add_node("generate_recommendation", generate_recommendation)
    graph.add_node("format_response", format_response)
    graph.set_entry_point("extract_query")
    graph.add_edge("extract_query", "search_pubmed")
    graph.add_edge("search_pubmed", "generate_recommendation")
    graph.add_edge("generate_recommendation", "format_response")
    graph.add_edge("format_response", END)
    return graph.compile(), llm
def run_agent(agent, patient_input: str, history: list) -> tuple:
    """Run the agent graph on one symptom description and render the result.

    Args:
        agent: Compiled graph returned by ``create_agent``; only its
            ``invoke(state)`` method is used.
        patient_input: Free-text symptom description.
        history: Previous conversation entries (may be empty or None). A copy
            is handed to the graph so the caller's list is never modified in
            place.

    Returns:
        ``(output_text, conversation_history)``: the formatted recommendation
        text and the history list from the graph's final state.
    """
    start_history = list(history or [])
    initial_state = {
        "messages": [],
        "patient_input": patient_input,
        "search_query": "",
        "pubmed_results": [],
        "clinical_recommendation": {},
        "conversation_history": start_history,
    }
    result = agent.invoke(initial_state)

    rec = result.get("clinical_recommendation")
    if not isinstance(rec, dict):
        # A non-object recommendation would otherwise crash the .get() calls below.
        rec = {}

    def as_items(value) -> list:
        """Normalise a field to a list of strings (a bare string is one item, not many characters)."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return [str(value)]

    def bullets(key: str) -> str:
        """Render one recommendation list field as bullet lines."""
        return "\n".join(f"• {item}" for item in as_items(rec.get(key)))

    sources = "\n".join(
        f"[{number}] {source}" for number, source in enumerate(as_items(rec.get("sources")), start=1)
    )
    sections = [
        f"🚨 URGENCY LEVEL: {rec.get('urgency_level', 'Unknown')}",
        f"🔬 POSSIBLE CONDITIONS:\n{bullets('possible_conditions')}",
        f"🧪 RECOMMENDED TESTS:\n{bullets('recommended_tests')}",
        f"💊 TREATMENT CONSIDERATIONS:\n{bullets('treatment_considerations')}",
        f"🧠 CLINICAL REASONING:\n{rec.get('reasoning', '')}",
        f"📚 SOURCES FROM PUBMED:\n{sources}",
        f"⚠️ DISCLAIMER: {rec.get('important_disclaimer', '')}",
    ]
    return "\n\n".join(sections), result.get("conversation_history", start_history)
def initialize_agent(api_key):
    """Build an agent from the API key entered in the UI.

    Returns ``(agent_or_None, status_message)`` for the ``agent_state`` and
    status-textbox outputs. Surrounding whitespace is stripped from the key
    because pasted keys often carry a trailing space or newline.

    NOTE(review): create_agent only constructs the client, so an invalid key
    presumably surfaces on the first analysis rather than here — confirm.
    """
    key = (api_key or "").strip()
    if not key:
        return None, "❌ Please enter a valid Groq API key"
    try:
        agent, _ = create_agent(key)
    except Exception as exc:  # UI boundary: report any construction failure as status text
        return None, f"❌ Error: {exc}"
    return agent, "✅ Agent initialized successfully!"


def analyze(agent, symptoms, history):
    """Run the agent on the symptom text and return ``(output_text, history)``."""
    if agent is None:
        return "❌ Please initialize the agent first with your Groq API key.", history
    if not (symptoms or "").strip():
        return "Please describe your symptoms.", history
    try:
        return run_agent(agent, symptoms, history)
    except Exception as exc:  # UI boundary: show API/network failures instead of a generic error
        return f"❌ Error: {exc}", history


with gr.Blocks(title="Clinical Decision Support Agent") as demo:
    gr.Markdown("# 🏥 Clinical Decision Support Agent")
    gr.Markdown("Powered by LangGraph + LLaMA 3.3 70B + Real PubMed Literature")
    gr.Markdown("⚠️ **For educational purposes only. Always consult a qualified healthcare professional.**")

    with gr.Row():
        api_key_input = gr.Textbox(
            label="🔑 Enter your Groq API Key",
            placeholder="gsk_xxxxxxxxxxxx",
            type="password",
        )

    # Hidden state: the conversation history and the compiled agent.
    history_state = gr.State([])
    agent_state = gr.State(None)

    init_btn = gr.Button("🚀 Initialize Agent", variant="primary")
    init_status = gr.Textbox(label="Status", interactive=False)
    init_btn.click(
        fn=initialize_agent,
        inputs=[api_key_input],
        outputs=[agent_state, init_status],
    )

    with gr.Row():
        with gr.Column():
            symptom_input = gr.Textbox(
                label="Describe Patient Symptoms",
                placeholder="Example: I have fever of 39°C for 3 days, cough with yellow sputum, chest pain...",
                lines=4,
            )
            submit_btn = gr.Button("🔍 Analyze Symptoms", variant="primary")
            clear_btn = gr.Button("🗑️ Clear Conversation")
        with gr.Column():
            output_text = gr.Textbox(
                label="Clinical Recommendation",
                lines=20,
                interactive=False,
            )

    submit_btn.click(
        fn=analyze,
        inputs=[agent_state, symptom_input, history_state],
        outputs=[output_text, history_state],
    )
    # Reset both the stored history and the displayed recommendation.
    clear_btn.click(
        fn=lambda: ([], ""),
        outputs=[history_state, output_text],
    )

demo.launch()