"""Gradio demo of a clinical decision-support agent.

A LangGraph pipeline turns free-text symptoms into a PubMed search query,
fetches matching abstracts from the NCBI E-utilities API, and asks a
Groq-hosted LLM for a structured recommendation that is rendered as text.
The output is educational only and is labelled as such in the UI.
"""

import json
import logging
import os
from typing import Annotated, List, TypedDict

import gradio as gr
import requests
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

logger = logging.getLogger(__name__)

EUTILS_BASE_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
PUBMED_ARTICLE_URL = "https://pubmed.ncbi.nlm.nih.gov/{pmid}/"
# Without a timeout, a stalled NCBI connection would block the UI forever.
REQUEST_TIMEOUT_SECONDS = 15
MAX_ABSTRACT_CHARS = 500
MAX_LISTED_AUTHORS = 3

QUERY_PROMPT = """You are a medical AI assistant. Generate the best PubMed search query for these symptoms.

Patient input: {patient_input}

Return ONLY the search query (max 8 words), nothing else."""

# Doubled braces are literal braces; the template is filled with str.format.
RECOMMENDATION_PROMPT = """You are an expert clinical decision support AI.

Patient symptoms: {patient_input}

Relevant Medical Literature:
{articles_text}

Generate a structured response in this EXACT JSON format:
{{
    "possible_conditions": ["condition1", "condition2", "condition3"],
    "recommended_tests": ["test1", "test2", "test3"],
    "treatment_considerations": ["consideration1", "consideration2"],
    "urgency_level": "Low/Medium/High/Emergency",
    "reasoning": "Brief explanation",
    "important_disclaimer": "This is AI-generated information for educational purposes only. Always consult a qualified healthcare professional.",
    "sources": ["Article title - Authors (Year)"]
}}

Return ONLY the JSON object."""


class AgentState(TypedDict):
    """State dictionary passed between the graph nodes."""

    messages: Annotated[list, add_messages]
    patient_input: str  # raw symptom description typed by the user
    search_query: str  # PubMed query produced by the LLM
    pubmed_results: List[dict]  # article dicts returned by search_pubmed
    clinical_recommendation: dict  # parsed LLM recommendation
    conversation_history: List[dict]  # one entry per completed run


def _placeholder_article(title: str, abstract: str) -> dict:
    """Return an article-shaped dict used when no real article is available."""
    return {
        "title": title,
        "abstract": abstract,
        "authors": "",
        "year": "",
        "pmid": "",
        "url": "",
    }


def _parse_article(article) -> dict:
    """Convert one ``PubmedArticle`` element into a flat article dict."""
    title_tag = article.find("ArticleTitle")
    abstract_tag = article.find("AbstractText")
    pub_date = article.find("PubDate")
    year_tag = pub_date.find("Year") if pub_date is not None else None
    pmid_tag = article.find("PMID")

    abstract = (
        abstract_tag.text if abstract_tag is not None else "No abstract available"
    )
    if len(abstract) > MAX_ABSTRACT_CHARS:
        abstract = abstract[:MAX_ABSTRACT_CHARS] + "..."

    last_names = [tag.text for tag in article.find_all("LastName")]
    if last_names:
        authors = ", ".join(last_names[:MAX_LISTED_AUTHORS])
        # Only claim "et al." when names were actually left out.
        if len(last_names) > MAX_LISTED_AUTHORS:
            authors += " et al."
    else:
        authors = "Unknown"

    pmid = pmid_tag.text if pmid_tag is not None else ""
    return {
        "title": title_tag.text if title_tag is not None else "No title",
        "abstract": abstract,
        "authors": authors,
        "year": year_tag.text if year_tag is not None else "Unknown",
        "pmid": pmid,
        # Avoid emitting a malformed ".../ /" link when the PMID is missing.
        "url": PUBMED_ARTICLE_URL.format(pmid=pmid) if pmid else "",
    }


def search_pubmed(query: str, max_results: int = 5) -> List[dict]:
    """Search PubMed and return up to ``max_results`` article summaries.

    Args:
        query: PubMed search term.
        max_results: Maximum number of article IDs requested from esearch.

    Returns:
        A non-empty list of dicts with the keys ``title``, ``abstract``,
        ``authors``, ``year``, ``pmid`` and ``url``. When the request fails,
        nothing matches, or nothing can be parsed, the list holds a single
        placeholder entry describing the situation.
    """
    try:
        search_response = requests.get(
            f"{EUTILS_BASE_URL}/esearch.fcgi",
            params={
                "db": "pubmed",
                "term": query,
                "retmax": max_results,
                "retmode": "json",
                "sort": "relevance",
            },
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        search_response.raise_for_status()
        article_ids = search_response.json()["esearchresult"]["idlist"]

        if not article_ids:
            return [
                _placeholder_article(
                    "No results found", "No relevant articles found."
                )
            ]

        fetch_response = requests.get(
            f"{EUTILS_BASE_URL}/efetch.fcgi",
            params={
                "db": "pubmed",
                "id": ",".join(article_ids),
                "retmode": "xml",
                "rettype": "abstract",
            },
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        fetch_response.raise_for_status()
    except (requests.RequestException, ValueError, KeyError) as exc:
        # Degrade to a placeholder so the rest of the graph can still run.
        logger.warning("PubMed lookup failed for %r: %s", query, exc)
        return [
            _placeholder_article(
                "PubMed unavailable", f"PubMed lookup failed: {exc}"
            )
        ]

    # Imported lazily so the module loads even before bs4 is needed.
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(fetch_response.text, "xml")
    articles = []
    for article in soup.find_all("PubmedArticle"):
        try:
            articles.append(_parse_article(article))
        except (AttributeError, TypeError) as exc:
            # Best effort: skip a malformed record instead of failing the batch.
            logger.debug("Skipping unparsable PubMed record: %s", exc)

    if articles:
        return articles
    return [_placeholder_article("Parse error", "Could not parse results.")]


def _parse_recommendation(raw) -> dict:
    """Parse the LLM reply into a recommendation dict.

    Accepts bare JSON or JSON wrapped in a Markdown code fence. Anything that
    does not decode to a JSON object yields a fallback dict that carries the
    raw reply in ``reasoning``.
    """
    text = raw.strip() if isinstance(raw, str) else ""
    if "```json" in text:
        text = text.split("```json", 1)[1].split("```", 1)[0].strip()
    elif "```" in text:
        text = text.split("```", 1)[1].split("```", 1)[0].strip()

    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None

    if isinstance(parsed, dict):
        return parsed
    return {
        "possible_conditions": ["Unable to parse"],
        "recommended_tests": [],
        "treatment_considerations": [],
        "urgency_level": "Unknown",
        "reasoning": raw if isinstance(raw, str) else str(raw),
        "important_disclaimer": "Always consult a qualified healthcare professional.",
        "sources": [],
    }


def create_agent(api_key: str):
    """Build the clinical decision-support graph.

    Args:
        api_key: Groq API key used to construct the chat model.

    Returns:
        A ``(compiled_graph, llm)`` tuple.
    """
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.1,
        api_key=api_key,
    )

    # Each node returns only the keys it changes; the graph merges them into
    # the state, so untouched keys (including ``messages``) are left alone.

    def extract_search_query(state: AgentState) -> dict:
        """Ask the LLM for a short PubMed query matching the symptoms."""
        prompt = QUERY_PROMPT.format(patient_input=state["patient_input"])
        response = llm.invoke([HumanMessage(content=prompt)])
        return {"search_query": response.content.strip()}

    def search_medical_literature(state: AgentState) -> dict:
        """Fetch PubMed articles for the generated query."""
        return {
            "pubmed_results": search_pubmed(state["search_query"], max_results=5)
        }

    def generate_recommendation(state: AgentState) -> dict:
        """Ask the LLM for a structured recommendation grounded in the articles."""
        articles_text = "".join(
            f"\nArticle {index}:\n"
            f"Title: {article['title']}\n"
            f"Authors: {article['authors']} ({article['year']})\n"
            f"Abstract: {article['abstract']}\n"
            f"URL: {article['url']}\n"
            for index, article in enumerate(state["pubmed_results"], start=1)
        )
        prompt = RECOMMENDATION_PROMPT.format(
            patient_input=state["patient_input"],
            articles_text=articles_text,
        )
        response = llm.invoke([HumanMessage(content=prompt)])
        return {
            "clinical_recommendation": _parse_recommendation(response.content)
        }

    def format_response(state: AgentState) -> dict:
        """Append this run to the conversation history."""
        entry = {
            "patient_input": state["patient_input"],
            "recommendation": state["clinical_recommendation"],
        }
        # Build a new list rather than mutating the caller's history in place.
        previous = state.get("conversation_history") or []
        return {"conversation_history": [*previous, entry]}

    graph = StateGraph(AgentState)
    graph.add_node("extract_query", extract_search_query)
    graph.add_node("search_pubmed", search_medical_literature)
    graph.add_node("generate_recommendation", generate_recommendation)
    graph.add_node("format_response", format_response)
    graph.set_entry_point("extract_query")
    graph.add_edge("extract_query", "search_pubmed")
    graph.add_edge("search_pubmed", "generate_recommendation")
    graph.add_edge("generate_recommendation", "format_response")
    graph.add_edge("format_response", END)
    return graph.compile(), llm


def _as_list(value) -> list:
    """Coerce an LLM-supplied field to a list (``None`` -> ``[]``, scalar -> ``[x]``)."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _bullets(items) -> str:
    """Render items as one bullet per line."""
    return "\n".join(f"• {item}" for item in _as_list(items))


def run_agent(agent, patient_input: str, history: list) -> tuple:
    """Run the graph once and format the recommendation for display.

    Args:
        agent: Compiled graph returned by ``create_agent``.
        patient_input: Free-text symptom description.
        history: Previous conversation entries; it is not modified.

    Returns:
        A ``(display_text, updated_history)`` tuple.
    """
    initial_state = {
        "messages": [],
        "patient_input": patient_input,
        "search_query": "",
        "pubmed_results": [],
        "clinical_recommendation": {},
        "conversation_history": list(history or []),
    }
    result = agent.invoke(initial_state)
    rec = result["clinical_recommendation"]

    sources = "\n".join(
        f"[{index}] {source}"
        for index, source in enumerate(_as_list(rec.get("sources")), start=1)
    )
    sections = [
        f"🚨 URGENCY LEVEL: {rec.get('urgency_level', 'Unknown')}",
        "🔬 POSSIBLE CONDITIONS:\n" + _bullets(rec.get("possible_conditions")),
        "🧪 RECOMMENDED TESTS:\n" + _bullets(rec.get("recommended_tests")),
        "💊 TREATMENT CONSIDERATIONS:\n"
        + _bullets(rec.get("treatment_considerations")),
        f"🧠 CLINICAL REASONING:\n{rec.get('reasoning', '')}",
        "📚 SOURCES FROM PUBMED:\n" + sources,
        f"⚠️ DISCLAIMER:\n{rec.get('important_disclaimer', '')}",
    ]
    return "\n\n".join(sections), result["conversation_history"]


with gr.Blocks(title="Clinical Decision Support Agent") as demo:
    gr.Markdown("# 🏥 Clinical Decision Support Agent")
    gr.Markdown("Powered by LangGraph + LLaMA 3.3 70B + Real PubMed Literature")
    gr.Markdown(
        "⚠️ **For educational purposes only. "
        "Always consult a qualified healthcare professional.**"
    )

    with gr.Row():
        api_key_input = gr.Textbox(
            label="🔑 Enter your Groq API Key",
            placeholder="gsk_xxxxxxxxxxxx",
            type="password",
        )

    history_state = gr.State([])
    agent_state = gr.State(None)

    def initialize_agent(api_key):
        """Create the agent from the entered key; return ``(agent, status)``."""
        api_key = (api_key or "").strip()
        if not api_key:
            return None, "❌ Please enter a valid Groq API key"
        try:
            agent, _ = create_agent(api_key)
        except Exception as exc:  # UI boundary: report instead of crashing
            logger.exception("Agent initialization failed")
            return None, f"❌ Error: {exc}"
        return agent, "✅ Agent initialized successfully!"

    init_btn = gr.Button("🚀 Initialize Agent", variant="primary")
    init_status = gr.Textbox(label="Status", interactive=False)
    init_btn.click(
        fn=initialize_agent,
        inputs=[api_key_input],
        outputs=[agent_state, init_status],
    )

    with gr.Row():
        with gr.Column():
            symptom_input = gr.Textbox(
                label="Describe Patient Symptoms",
                placeholder=(
                    "Example: I have fever of 39°C for 3 days, "
                    "cough with yellow sputum, chest pain..."
                ),
                lines=4,
            )
            submit_btn = gr.Button("🔍 Analyze Symptoms", variant="primary")
            clear_btn = gr.Button("🗑️ Clear Conversation")
        with gr.Column():
            output_text = gr.Textbox(
                label="Clinical Recommendation",
                lines=20,
                interactive=False,
            )

    def analyze(agent, symptoms, history):
        """Validate the inputs, run the agent, and return ``(text, history)``."""
        if agent is None:
            return (
                "❌ Please initialize the agent first with your Groq API key.",
                history,
            )
        if not (symptoms or "").strip():
            return "Please describe your symptoms.", history
        try:
            return run_agent(agent, symptoms, history)
        except Exception as exc:  # UI boundary: show the failure to the user
            logger.exception("Analysis failed")
            return f"❌ Error: {exc}", history

    submit_btn.click(
        fn=analyze,
        inputs=[agent_state, symptom_input, history_state],
        outputs=[output_text, history_state],
    )
    clear_btn.click(
        fn=lambda: ([], ""),
        outputs=[history_state, output_text],
    )


if __name__ == "__main__":
    demo.launch()