azlaan428 commited on
Commit ·
d62c791
1
Parent(s): 1ef27ba
feat: PRISMA filter, follow-up questions, rate limit retry, staggered post-pipeline calls
Browse files- agent/agent.py +70 -10
- app.py +44 -14
- sessions.json +0 -0
- templates/index.html +3 -3
agent/agent.py
CHANGED
|
@@ -9,6 +9,7 @@ from retrieval.pubmed import fetch_pubmed
|
|
| 9 |
|
| 10 |
|
| 11 |
def get_llm():
|
|
|
|
| 12 |
return ChatGroq(
|
| 13 |
model="llama-3.1-8b-instant",
|
| 14 |
temperature=0,
|
|
@@ -16,6 +17,21 @@ def get_llm():
|
|
| 16 |
)
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
@tool
|
| 20 |
def PubMedSearch(query: str) -> str:
|
| 21 |
"""Searches PubMed for biomedical literature abstracts."""
|
|
@@ -37,7 +53,7 @@ def run_query_architect(user_question):
|
|
| 37 |
"Return ONLY a numbered list 1-5, one query per line, no explanations.\n\n"
|
| 38 |
"Question: " + user_question
|
| 39 |
)
|
| 40 |
-
response =
|
| 41 |
raw_lines = response.content.strip().split("\n")
|
| 42 |
queries = []
|
| 43 |
for line in raw_lines:
|
|
@@ -87,7 +103,7 @@ def run_evidence_synthesiser(user_question, papers):
|
|
| 87 |
"Retrieved Literature:\n" + corpus + "\n\n"
|
| 88 |
"Be precise and cite PMIDs throughout."
|
| 89 |
)
|
| 90 |
-
response =
|
| 91 |
return response.content
|
| 92 |
|
| 93 |
|
|
@@ -104,7 +120,6 @@ def run_citation_builder(papers):
|
|
| 104 |
return "\n".join(result_lines)
|
| 105 |
|
| 106 |
|
| 107 |
-
|
| 108 |
def run_confidence_scorer(synthesis):
|
| 109 |
llm = get_llm()
|
| 110 |
prompt = (
|
|
@@ -122,7 +137,7 @@ def run_confidence_scorer(synthesis):
|
|
| 122 |
"Scores: 8-10 = strong evidence, 5-7 = moderate, 1-4 = weak/preliminary.\n\n"
|
| 123 |
"Synthesis:\n" + synthesis
|
| 124 |
)
|
| 125 |
-
response =
|
| 126 |
import json
|
| 127 |
text = response.content.strip()
|
| 128 |
text = text.replace("```json", "").replace("```", "").strip()
|
|
@@ -148,7 +163,7 @@ def run_selective_review(user_question, selected_papers):
|
|
| 148 |
"Question: " + user_question + "\n\n"
|
| 149 |
"Selected Papers:\n" + corpus
|
| 150 |
)
|
| 151 |
-
response =
|
| 152 |
return response.content
|
| 153 |
|
| 154 |
|
|
@@ -163,17 +178,17 @@ def run_predictive_model(user_question, synthesis):
|
|
| 163 |
"## Destructive Forecast\n"
|
| 164 |
"2-3 sentences: Which current assumptions, treatments, or paradigms does the evidence suggest "
|
| 165 |
"may be challenged, overturned, or significantly revised in coming years?\n\n"
|
| 166 |
-
"IMPORTANT: Always produce both sections even if evidence is limited. Never ask for more input.\n"
|
|
|
|
| 167 |
"Clinical Question: " + user_question + "\n\n"
|
| 168 |
"Synthesis:\n" + synthesis
|
| 169 |
)
|
| 170 |
-
response =
|
| 171 |
return response.content
|
| 172 |
|
| 173 |
|
| 174 |
def run_table_extractor(user_question, synthesis, papers):
|
| 175 |
llm = get_llm()
|
| 176 |
-
# Build a brief paper list for context
|
| 177 |
paper_list = []
|
| 178 |
for pmid, p in list(papers.items())[:10]:
|
| 179 |
paper_list.append("PMID " + pmid + ": " + p.get("title", "N/A") + " (" + p.get("year", "") + ")")
|
|
@@ -200,11 +215,56 @@ def run_table_extractor(user_question, synthesis, papers):
|
|
| 200 |
"Papers:\n" + papers_str + "\n\n"
|
| 201 |
"Synthesis:\n" + synthesis[:1500]
|
| 202 |
)
|
| 203 |
-
response =
|
| 204 |
import json
|
| 205 |
text = response.content.strip().replace("```json", "").replace("```", "").strip()
|
| 206 |
return json.loads(text)
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
def run_pipeline(user_question):
|
| 209 |
print("[1/4] Query Architect: generating search queries...")
|
| 210 |
queries = run_query_architect(user_question)
|
|
@@ -237,4 +297,4 @@ if __name__ == "__main__":
|
|
| 237 |
print("\n=== SYNTHESIS ===")
|
| 238 |
print(result["synthesis"])
|
| 239 |
print("\n=== REFERENCES ===")
|
| 240 |
-
print(result["citations"])
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def get_llm():
|
| 12 |
+
from langchain_groq import ChatGroq
|
| 13 |
return ChatGroq(
|
| 14 |
model="llama-3.1-8b-instant",
|
| 15 |
temperature=0,
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
|
| 20 |
+
def llm_invoke_with_retry(llm, prompt, max_retries=5):
    """Invoke ``llm`` with ``prompt``, retrying on rate-limit errors.

    A response is retried only when the provider error looks like an HTTP 429 /
    rate-limit condition ("429" or "rate_limit" in the message); any other
    exception is re-raised immediately. Backoff is linear: 10s, 20s, 30s, ...

    Args:
        llm: chat-model object exposing ``invoke(prompt)``.
        prompt: prompt string passed straight through to ``llm.invoke``.
        max_retries: maximum number of invocation attempts (default 5).

    Returns:
        Whatever ``llm.invoke(prompt)`` returns on the first successful attempt.

    Raises:
        RuntimeError: when every attempt hit a rate limit (chained to the last
            provider exception for debuggability).
        Exception: any non-rate-limit error from ``llm.invoke``, unchanged.
    """
    import time
    last_err = None
    for attempt in range(max_retries):
        try:
            return llm.invoke(prompt)
        except Exception as e:
            if "429" in str(e) or "rate_limit" in str(e).lower():
                last_err = e
                # Bug fix: don't sleep after the FINAL failed attempt — the
                # original slept 10*max_retries seconds before giving up.
                if attempt < max_retries - 1:
                    wait = 10 * (attempt + 1)
                    print(f"[ARIA] Rate limit hit, waiting {wait}s (attempt {attempt+1}/{max_retries})")
                    time.sleep(wait)
            else:
                raise
    # Chain the last provider error so the root cause isn't lost.
    raise RuntimeError("Max retries exceeded on rate limit") from last_err
|
| 33 |
+
|
| 34 |
+
|
| 35 |
@tool
|
| 36 |
def PubMedSearch(query: str) -> str:
|
| 37 |
"""Searches PubMed for biomedical literature abstracts."""
|
|
|
|
| 53 |
"Return ONLY a numbered list 1-5, one query per line, no explanations.\n\n"
|
| 54 |
"Question: " + user_question
|
| 55 |
)
|
| 56 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 57 |
raw_lines = response.content.strip().split("\n")
|
| 58 |
queries = []
|
| 59 |
for line in raw_lines:
|
|
|
|
| 103 |
"Retrieved Literature:\n" + corpus + "\n\n"
|
| 104 |
"Be precise and cite PMIDs throughout."
|
| 105 |
)
|
| 106 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 107 |
return response.content
|
| 108 |
|
| 109 |
|
|
|
|
| 120 |
return "\n".join(result_lines)
|
| 121 |
|
| 122 |
|
|
|
|
| 123 |
def run_confidence_scorer(synthesis):
|
| 124 |
llm = get_llm()
|
| 125 |
prompt = (
|
|
|
|
| 137 |
"Scores: 8-10 = strong evidence, 5-7 = moderate, 1-4 = weak/preliminary.\n\n"
|
| 138 |
"Synthesis:\n" + synthesis
|
| 139 |
)
|
| 140 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 141 |
import json
|
| 142 |
text = response.content.strip()
|
| 143 |
text = text.replace("```json", "").replace("```", "").strip()
|
|
|
|
| 163 |
"Question: " + user_question + "\n\n"
|
| 164 |
"Selected Papers:\n" + corpus
|
| 165 |
)
|
| 166 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 167 |
return response.content
|
| 168 |
|
| 169 |
|
|
|
|
| 178 |
"## Destructive Forecast\n"
|
| 179 |
"2-3 sentences: Which current assumptions, treatments, or paradigms does the evidence suggest "
|
| 180 |
"may be challenged, overturned, or significantly revised in coming years?\n\n"
|
| 181 |
+
"IMPORTANT: Always produce both sections even if evidence is limited. Never ask for more input.\n"
|
| 182 |
+
"Be specific and grounded in the evidence. No speculation beyond what the data implies.\n\n"
|
| 183 |
"Clinical Question: " + user_question + "\n\n"
|
| 184 |
"Synthesis:\n" + synthesis
|
| 185 |
)
|
| 186 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 187 |
return response.content
|
| 188 |
|
| 189 |
|
| 190 |
def run_table_extractor(user_question, synthesis, papers):
|
| 191 |
llm = get_llm()
|
|
|
|
| 192 |
paper_list = []
|
| 193 |
for pmid, p in list(papers.items())[:10]:
|
| 194 |
paper_list.append("PMID " + pmid + ": " + p.get("title", "N/A") + " (" + p.get("year", "") + ")")
|
|
|
|
| 215 |
"Papers:\n" + papers_str + "\n\n"
|
| 216 |
"Synthesis:\n" + synthesis[:1500]
|
| 217 |
)
|
| 218 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 219 |
import json
|
| 220 |
text = response.content.strip().replace("```json", "").replace("```", "").strip()
|
| 221 |
return json.loads(text)
|
| 222 |
|
| 223 |
+
|
| 224 |
+
def run_prisma_filter(user_question, papers):
    """Screen retrieved papers against PRISMA-style include/exclude criteria.

    Asks the LLM to decide, for each paper, whether it should be included in
    the synthesis for ``user_question``. Each paper keeps its original fields
    and gains two keys: ``included`` (bool) and ``reason`` (str).

    Robustness: if the model returns malformed JSON, or omits papers, the
    filter fails OPEN — unreviewed papers are kept (``included=True``) so a
    flaky screening step never silently discards evidence or crashes the
    pipeline (the original raised on any JSON parse error).

    Args:
        user_question: the clinical question driving inclusion decisions.
        papers: dict mapping PMID -> paper dict (``title``, ``abstract``, ...).

    Returns:
        dict mapping PMID -> paper dict augmented with ``included``/``reason``.
    """
    import json
    # Nothing to screen — skip the LLM round-trip entirely.
    if not papers:
        return {}
    llm = get_llm()
    paper_list = [
        "PMID " + pmid + ": " + p.get("title", "N/A") + "\n" +
        p.get("abstract", "")[:200]  # first 200 chars keeps the prompt small
        for pmid, p in papers.items()
    ]
    corpus = "\n\n".join(paper_list)
    prompt = (
        "You are a systematic review methodologist applying PRISMA screening criteria.\n"
        "For each paper, decide if it should be INCLUDED or EXCLUDED for answering this clinical question.\n"
        "Return ONLY valid JSON, no markdown, no explanation.\n"
        "Format:\n"
        "{\n"
        ' "decisions": [\n'
        ' {"pmid": "12345678", "decision": "included", "reason": "one sentence"},\n'
        ' {"pmid": "87654321", "decision": "excluded", "reason": "one sentence"}\n'
        ' ]\n'
        "}\n\n"
        "Inclusion criteria: directly relevant to the clinical question, has empirical data or clinical findings.\n"
        "Exclusion criteria: off-topic, editorial, commentary without data, animal studies if human data exists.\n\n"
        "Clinical Question: " + user_question + "\n\n"
        "Papers:\n" + corpus
    )
    response = llm_invoke_with_retry(llm, prompt)
    text = response.content.strip().replace("```json", "").replace("```", "").strip()
    try:
        decisions = json.loads(text)["decisions"]
    except (ValueError, KeyError, TypeError):
        # Fail open: malformed model output must not kill the pipeline.
        return {pmid: {**p, "included": True, "reason": "Screening failed"}
                for pmid, p in papers.items()}
    result = {}
    for d in decisions:
        pmid = d.get("pmid")
        # Ignore hallucinated PMIDs the model invented.
        if pmid in papers:
            result[pmid] = {
                **papers[pmid],
                # Case-insensitive compare: models sometimes emit "Included".
                "included": str(d.get("decision", "")).lower() == "included",
                "reason": d.get("reason", ""),
            }
    # Any paper the model skipped is kept, flagged as not reviewed.
    for pmid in papers:
        if pmid not in result:
            result[pmid] = {**papers[pmid], "included": True, "reason": "Not reviewed"}
    return result
|
| 266 |
+
|
| 267 |
+
|
| 268 |
def run_pipeline(user_question):
|
| 269 |
print("[1/4] Query Architect: generating search queries...")
|
| 270 |
queries = run_query_architect(user_question)
|
|
|
|
| 297 |
print("\n=== SYNTHESIS ===")
|
| 298 |
print(result["synthesis"])
|
| 299 |
print("\n=== REFERENCES ===")
|
| 300 |
+
print(result["citations"])
|
app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
import sys, os, json
|
| 2 |
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 3 |
from flask import Flask, render_template, request, jsonify, send_file, Response, stream_with_context
|
| 4 |
-
from agent.agent import run_pipeline, run_query_architect, run_literature_scout,
|
|
|
|
| 5 |
from reportlab.lib.pagesizes import A4
|
| 6 |
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
| 7 |
from reportlab.lib.units import mm
|
|
@@ -55,22 +56,46 @@ def stream():
|
|
| 55 |
# Stage 2
|
| 56 |
yield emit("stage", {"stage": 2, "pct": 35})
|
| 57 |
papers = run_literature_scout(queries)
|
| 58 |
-
yield emit("papers", {"paper_count": len(papers), "pct":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
# Stage
|
| 61 |
-
yield emit("stage", {"stage":
|
| 62 |
-
synthesis = run_evidence_synthesiser(user_query,
|
| 63 |
yield emit("synthesis", {"synthesis": synthesis, "pct": 88})
|
| 64 |
|
| 65 |
-
# Stage
|
| 66 |
-
yield emit("stage", {"stage":
|
| 67 |
-
citations = run_citation_builder(
|
| 68 |
yield emit("done", {
|
| 69 |
"synthesis": synthesis,
|
| 70 |
"citations": citations,
|
| 71 |
-
"paper_count": len(
|
| 72 |
"queries": queries,
|
| 73 |
-
"papers": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
"pct": 100
|
| 75 |
})
|
| 76 |
|
|
@@ -175,7 +200,6 @@ def export_pdf():
|
|
| 175 |
as_attachment=True, download_name=filename)
|
| 176 |
|
| 177 |
|
| 178 |
-
|
| 179 |
@app.route("/score", methods=["POST"])
|
| 180 |
def score():
|
| 181 |
data = request.get_json()
|
|
@@ -224,22 +248,26 @@ import json as _json
|
|
| 224 |
from datetime import datetime
|
| 225 |
SESSIONS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sessions.json")
|
| 226 |
|
|
|
|
| 227 |
def load_sessions():
    """Load the saved session list from SESSIONS_FILE.

    Returns:
        list: previously saved session entries, newest first; an empty list
        when the file is missing, unreadable, or contains invalid JSON.
    """
    # Context manager closes the handle (original leaked the open file),
    # and the narrowed except no longer swallows e.g. KeyboardInterrupt.
    try:
        with open(SESSIONS_FILE) as f:
            return _json.load(f)
    except (OSError, ValueError):
        return []
|
| 232 |
|
|
|
|
| 233 |
def save_session(entry):
    """Prepend ``entry`` to the persisted session list, keeping the 20 newest.

    Args:
        entry: JSON-serializable session record to store at the head of the list.
    """
    sessions = load_sessions()
    sessions.insert(0, entry)
    sessions = sessions[:20]
    # 'with' guarantees the handle is flushed and closed (original passed a
    # bare open() to _json.dump and never closed it).
    with open(SESSIONS_FILE, "w") as f:
        _json.dump(sessions, f, indent=2)
|
| 238 |
|
|
|
|
| 239 |
@app.route("/sessions", methods=["GET"])
def get_sessions():
    """Return all persisted sessions as JSON: {"sessions": [...]}."""
    sessions = load_sessions()
    return jsonify({"sessions": sessions})
|
| 242 |
|
|
|
|
| 243 |
@app.route("/sessions/save", methods=["POST"])
|
| 244 |
def save_session_route():
|
| 245 |
data = request.get_json()
|
|
@@ -273,6 +301,7 @@ def extract_table():
|
|
| 273 |
traceback.print_exc()
|
| 274 |
return jsonify({"error": str(e)}), 500
|
| 275 |
|
|
|
|
| 276 |
@app.route("/followup", methods=["POST"])
|
| 277 |
def followup():
|
| 278 |
data = request.get_json()
|
|
@@ -298,10 +327,11 @@ def followup():
|
|
| 298 |
f"Papers:\n{corpus}\n\n"
|
| 299 |
f"Follow-up Question: {question}"
|
| 300 |
)
|
| 301 |
-
response =
|
| 302 |
return jsonify({"answer": response.content})
|
| 303 |
except Exception as e:
|
| 304 |
return jsonify({"error": str(e)}), 500
|
| 305 |
|
|
|
|
| 306 |
if __name__ == "__main__":
|
| 307 |
-
app.run(debug=True, port=5000, threaded=True)
|
|
|
|
| 1 |
+
import sys, os, json, time
|
| 2 |
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 3 |
from flask import Flask, render_template, request, jsonify, send_file, Response, stream_with_context
|
| 4 |
+
from agent.agent import (run_pipeline, run_query_architect, run_literature_scout,
|
| 5 |
+
run_evidence_synthesiser, run_citation_builder, llm_invoke_with_retry)
|
| 6 |
from reportlab.lib.pagesizes import A4
|
| 7 |
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
| 8 |
from reportlab.lib.units import mm
|
|
|
|
| 56 |
# Stage 2
|
| 57 |
yield emit("stage", {"stage": 2, "pct": 35})
|
| 58 |
papers = run_literature_scout(queries)
|
| 59 |
+
yield emit("papers", {"paper_count": len(papers), "pct": 50})
|
| 60 |
+
|
| 61 |
+
# PRISMA filter
|
| 62 |
+
yield emit("stage", {"stage": 3, "pct": 55})
|
| 63 |
+
from agent.agent import run_prisma_filter
|
| 64 |
+
filtered = run_prisma_filter(user_query, papers)
|
| 65 |
+
included = {pmid: p for pmid, p in filtered.items() if p["included"]}
|
| 66 |
+
yield emit("prisma", {
|
| 67 |
+
"filtered": {
|
| 68 |
+
pmid: {"title": p.get("title", ""), "included": p["included"], "reason": p["reason"]}
|
| 69 |
+
for pmid, p in filtered.items()
|
| 70 |
+
},
|
| 71 |
+
"included_count": len(included),
|
| 72 |
+
"excluded_count": len(filtered) - len(included),
|
| 73 |
+
"pct": 65
|
| 74 |
+
})
|
| 75 |
+
time.sleep(12)
|
| 76 |
|
| 77 |
+
# Stage 4 - synthesise on included papers only
|
| 78 |
+
yield emit("stage", {"stage": 4, "pct": 70})
|
| 79 |
+
synthesis = run_evidence_synthesiser(user_query, included)
|
| 80 |
yield emit("synthesis", {"synthesis": synthesis, "pct": 88})
|
| 81 |
|
| 82 |
+
# Stage 5
|
| 83 |
+
yield emit("stage", {"stage": 5, "pct": 90})
|
| 84 |
+
citations = run_citation_builder(included)
|
| 85 |
yield emit("done", {
|
| 86 |
"synthesis": synthesis,
|
| 87 |
"citations": citations,
|
| 88 |
+
"paper_count": len(included),
|
| 89 |
"queries": queries,
|
| 90 |
+
"papers": {
|
| 91 |
+
pmid: {
|
| 92 |
+
"title": p.get("title", ""),
|
| 93 |
+
"abstract": p.get("abstract", ""),
|
| 94 |
+
"authors": p.get("authors", ""),
|
| 95 |
+
"journal": p.get("journal", ""),
|
| 96 |
+
"year": p.get("year", "")
|
| 97 |
+
} for pmid, p in included.items()
|
| 98 |
+
},
|
| 99 |
"pct": 100
|
| 100 |
})
|
| 101 |
|
|
|
|
| 200 |
as_attachment=True, download_name=filename)
|
| 201 |
|
| 202 |
|
|
|
|
| 203 |
@app.route("/score", methods=["POST"])
|
| 204 |
def score():
|
| 205 |
data = request.get_json()
|
|
|
|
| 248 |
from datetime import datetime
|
| 249 |
SESSIONS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sessions.json")
|
| 250 |
|
| 251 |
+
|
| 252 |
def load_sessions():
|
| 253 |
try:
|
| 254 |
return _json.load(open(SESSIONS_FILE))
|
| 255 |
except:
|
| 256 |
return []
|
| 257 |
|
| 258 |
+
|
| 259 |
def save_session(entry):
|
| 260 |
sessions = load_sessions()
|
| 261 |
sessions.insert(0, entry)
|
| 262 |
sessions = sessions[:20]
|
| 263 |
_json.dump(sessions, open(SESSIONS_FILE, "w"), indent=2)
|
| 264 |
|
| 265 |
+
|
| 266 |
@app.route("/sessions", methods=["GET"])
|
| 267 |
def get_sessions():
|
| 268 |
return jsonify({"sessions": load_sessions()})
|
| 269 |
|
| 270 |
+
|
| 271 |
@app.route("/sessions/save", methods=["POST"])
|
| 272 |
def save_session_route():
|
| 273 |
data = request.get_json()
|
|
|
|
| 301 |
traceback.print_exc()
|
| 302 |
return jsonify({"error": str(e)}), 500
|
| 303 |
|
| 304 |
+
|
| 305 |
@app.route("/followup", methods=["POST"])
|
| 306 |
def followup():
|
| 307 |
data = request.get_json()
|
|
|
|
| 327 |
f"Papers:\n{corpus}\n\n"
|
| 328 |
f"Follow-up Question: {question}"
|
| 329 |
)
|
| 330 |
+
response = llm_invoke_with_retry(llm, prompt)
|
| 331 |
return jsonify({"answer": response.content})
|
| 332 |
except Exception as e:
|
| 333 |
return jsonify({"error": str(e)}), 500
|
| 334 |
|
| 335 |
+
|
| 336 |
if __name__ == "__main__":
|
| 337 |
+
app.run(debug=True, port=5000, threaded=True)
|
sessions.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
templates/index.html
CHANGED
|
@@ -1058,9 +1058,9 @@ async function submitQuery() {
|
|
| 1058 |
setStage(5);
|
| 1059 |
es.close();
|
| 1060 |
renderResults(data);
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
saveSession(data, q);
|
| 1065 |
btn.disabled = false;
|
| 1066 |
});
|
|
|
|
| 1058 |
setStage(5);
|
| 1059 |
es.close();
|
| 1060 |
renderResults(data);
|
| 1061 |
+
setTimeout(() => scoreResults(data.synthesis), 2000);
|
| 1062 |
+
setTimeout(() => runPredictiveModel(lastQuery, data.synthesis), 20000);
|
| 1063 |
+
setTimeout(() => buildTable(lastQuery, data.synthesis, data.papers || {}), 38000);
|
| 1064 |
saveSession(data, q);
|
| 1065 |
btn.disabled = false;
|
| 1066 |
});
|