azlaan428 committed
Commit: d62c791 · Parent(s): 1ef27ba

feat: PRISMA filter, follow-up questions, rate limit retry, staggered post-pipeline calls

Files changed (4)
  1. agent/agent.py +70 -10
  2. app.py +44 -14
  3. sessions.json +0 -0
  4. templates/index.html +3 -3
agent/agent.py CHANGED
@@ -9,6 +9,7 @@ from retrieval.pubmed import fetch_pubmed
 
 
 def get_llm():
+    from langchain_groq import ChatGroq
     return ChatGroq(
         model="llama-3.1-8b-instant",
         temperature=0,
@@ -16,6 +17,21 @@ def get_llm():
     )
 
 
+def llm_invoke_with_retry(llm, prompt, max_retries=5):
+    import time
+    for attempt in range(max_retries):
+        try:
+            return llm.invoke(prompt)
+        except Exception as e:
+            if "429" in str(e) or "rate_limit" in str(e).lower():
+                wait = 10 * (attempt + 1)
+                print(f"[ARIA] Rate limit hit, waiting {wait}s (attempt {attempt+1}/{max_retries})")
+                time.sleep(wait)
+            else:
+                raise
+    raise RuntimeError("Max retries exceeded on rate limit")
+
+
 @tool
 def PubMedSearch(query: str) -> str:
     """Searches PubMed for biomedical literature abstracts."""
@@ -37,7 +53,7 @@ def run_query_architect(user_question):
         "Return ONLY a numbered list 1-5, one query per line, no explanations.\n\n"
         "Question: " + user_question
     )
-    response = llm.invoke(prompt)
+    response = llm_invoke_with_retry(llm, prompt)
     raw_lines = response.content.strip().split("\n")
     queries = []
     for line in raw_lines:
@@ -87,7 +103,7 @@ def run_evidence_synthesiser(user_question, papers):
         "Retrieved Literature:\n" + corpus + "\n\n"
         "Be precise and cite PMIDs throughout."
     )
-    response = llm.invoke(prompt)
+    response = llm_invoke_with_retry(llm, prompt)
     return response.content
 
 
@@ -104,7 +120,6 @@ def run_citation_builder(papers):
     return "\n".join(result_lines)
 
 
-
 def run_confidence_scorer(synthesis):
     llm = get_llm()
     prompt = (
@@ -122,7 +137,7 @@ def run_confidence_scorer(synthesis):
         "Scores: 8-10 = strong evidence, 5-7 = moderate, 1-4 = weak/preliminary.\n\n"
         "Synthesis:\n" + synthesis
     )
-    response = llm.invoke(prompt)
+    response = llm_invoke_with_retry(llm, prompt)
     import json
     text = response.content.strip()
     text = text.replace("```json", "").replace("```", "").strip()
@@ -148,7 +163,7 @@ def run_selective_review(user_question, selected_papers):
         "Question: " + user_question + "\n\n"
         "Selected Papers:\n" + corpus
     )
-    response = llm.invoke(prompt)
+    response = llm_invoke_with_retry(llm, prompt)
    return response.content
 
 
@@ -163,17 +178,17 @@ def run_predictive_model(user_question, synthesis):
         "## Destructive Forecast\n"
         "2-3 sentences: Which current assumptions, treatments, or paradigms does the evidence suggest "
         "may be challenged, overturned, or significantly revised in coming years?\n\n"
-        "IMPORTANT: Always produce both sections even if evidence is limited. Never ask for more input.\n""Be specific and grounded in the evidence. No speculation beyond what the data implies.\n\n"
+        "IMPORTANT: Always produce both sections even if evidence is limited. Never ask for more input.\n"
+        "Be specific and grounded in the evidence. No speculation beyond what the data implies.\n\n"
         "Clinical Question: " + user_question + "\n\n"
         "Synthesis:\n" + synthesis
     )
-    response = llm.invoke(prompt)
+    response = llm_invoke_with_retry(llm, prompt)
     return response.content
 
 
 def run_table_extractor(user_question, synthesis, papers):
     llm = get_llm()
-    # Build a brief paper list for context
     paper_list = []
     for pmid, p in list(papers.items())[:10]:
         paper_list.append("PMID " + pmid + ": " + p.get("title", "N/A") + " (" + p.get("year", "") + ")")
@@ -200,11 +215,56 @@ def run_table_extractor(user_question, synthesis, papers):
         "Papers:\n" + papers_str + "\n\n"
         "Synthesis:\n" + synthesis[:1500]
     )
-    response = llm.invoke(prompt)
+    response = llm_invoke_with_retry(llm, prompt)
     import json
     text = response.content.strip().replace("```json", "").replace("```", "").strip()
     return json.loads(text)
 
+
+def run_prisma_filter(user_question, papers):
+    llm = get_llm()
+    import json
+    paper_list = []
+    for pmid, p in papers.items():
+        paper_list.append(
+            "PMID " + pmid + ": " + p.get("title", "N/A") + "\n" +
+            p.get("abstract", "")[:200]
+        )
+    corpus = "\n\n".join(paper_list)
+    prompt = (
+        "You are a systematic review methodologist applying PRISMA screening criteria.\n"
+        "For each paper, decide if it should be INCLUDED or EXCLUDED for answering this clinical question.\n"
+        "Return ONLY valid JSON, no markdown, no explanation.\n"
+        "Format:\n"
+        "{\n"
+        '  "decisions": [\n'
+        '    {"pmid": "12345678", "decision": "included", "reason": "one sentence"},\n'
+        '    {"pmid": "87654321", "decision": "excluded", "reason": "one sentence"}\n'
+        '  ]\n'
+        "}\n\n"
+        "Inclusion criteria: directly relevant to the clinical question, has empirical data or clinical findings.\n"
+        "Exclusion criteria: off-topic, editorial, commentary without data, animal studies if human data exists.\n\n"
+        "Clinical Question: " + user_question + "\n\n"
+        "Papers:\n" + corpus
+    )
+    response = llm_invoke_with_retry(llm, prompt)
+    text = response.content.strip().replace("```json", "").replace("```", "").strip()
+    data = json.loads(text)
+    result = {}
+    for d in data["decisions"]:
+        pmid = d["pmid"]
+        if pmid in papers:
+            result[pmid] = {
+                **papers[pmid],
+                "included": d["decision"] == "included",
+                "reason": d["reason"]
+            }
+    for pmid in papers:
+        if pmid not in result:
+            result[pmid] = {**papers[pmid], "included": True, "reason": "Not reviewed"}
+    return result
+
+
 def run_pipeline(user_question):
     print("[1/4] Query Architect: generating search queries...")
     queries = run_query_architect(user_question)
@@ -237,4 +297,4 @@ if __name__ == "__main__":
     print("\n=== SYNTHESIS ===")
     print(result["synthesis"])
     print("\n=== REFERENCES ===")
-    print(result["citations"])
+    print(result["citations"])
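The retry helper is the load-bearing change here: every direct llm.invoke call is routed through llm_invoke_with_retry, which retries only on 429/rate-limit errors with linear backoff (10 s, 20 s, ... up to max_retries) and re-raises anything else. A minimal sketch of how it behaves, using a hypothetical FlakyLLM stub that is not part of this commit:

from agent.agent import llm_invoke_with_retry

# Hypothetical stub: raises a 429-style error twice, then succeeds.
class FlakyLLM:
    def __init__(self):
        self.calls = 0

    def invoke(self, prompt):
        self.calls += 1
        if self.calls < 3:
            raise RuntimeError("Error code: 429 - rate_limit_exceeded")
        return type("Resp", (), {"content": "ok"})()

# The helper sleeps 10s after the first failure and 20s after the second,
# then returns the successful response on the third attempt. A non-rate-limit
# exception would propagate immediately instead of being retried.
response = llm_invoke_with_retry(FlakyLLM(), "test prompt")
print(response.content)  # "ok"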
app.py CHANGED
@@ -1,7 +1,8 @@
-import sys, os, json
+import sys, os, json, time
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from flask import Flask, render_template, request, jsonify, send_file, Response, stream_with_context
-from agent.agent import run_pipeline, run_query_architect, run_literature_scout, run_evidence_synthesiser, run_citation_builder
+from agent.agent import (run_pipeline, run_query_architect, run_literature_scout,
+                         run_evidence_synthesiser, run_citation_builder, llm_invoke_with_retry)
 from reportlab.lib.pagesizes import A4
 from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
 from reportlab.lib.units import mm
@@ -55,22 +56,46 @@ def stream():
         # Stage 2
         yield emit("stage", {"stage": 2, "pct": 35})
         papers = run_literature_scout(queries)
-        yield emit("papers", {"paper_count": len(papers), "pct": 55})
+        yield emit("papers", {"paper_count": len(papers), "pct": 50})
+
+        # PRISMA filter
+        yield emit("stage", {"stage": 3, "pct": 55})
+        from agent.agent import run_prisma_filter
+        filtered = run_prisma_filter(user_query, papers)
+        included = {pmid: p for pmid, p in filtered.items() if p["included"]}
+        yield emit("prisma", {
+            "filtered": {
+                pmid: {"title": p.get("title", ""), "included": p["included"], "reason": p["reason"]}
+                for pmid, p in filtered.items()
+            },
+            "included_count": len(included),
+            "excluded_count": len(filtered) - len(included),
+            "pct": 65
+        })
+        time.sleep(12)
 
-        # Stage 3
-        yield emit("stage", {"stage": 3, "pct": 70})
-        synthesis = run_evidence_synthesiser(user_query, papers)
+        # Stage 4 - synthesise on included papers only
+        yield emit("stage", {"stage": 4, "pct": 70})
+        synthesis = run_evidence_synthesiser(user_query, included)
         yield emit("synthesis", {"synthesis": synthesis, "pct": 88})
 
-        # Stage 4
-        yield emit("stage", {"stage": 4, "pct": 90})
-        citations = run_citation_builder(papers)
+        # Stage 5
+        yield emit("stage", {"stage": 5, "pct": 90})
+        citations = run_citation_builder(included)
         yield emit("done", {
             "synthesis": synthesis,
             "citations": citations,
-            "paper_count": len(papers),
+            "paper_count": len(included),
             "queries": queries,
-            "papers": {pmid: {"title": p.get("title",""), "abstract": p.get("abstract",""), "authors": p.get("authors",""), "journal": p.get("journal",""), "year": p.get("year","")} for pmid, p in papers.items()},
+            "papers": {
+                pmid: {
+                    "title": p.get("title", ""),
+                    "abstract": p.get("abstract", ""),
+                    "authors": p.get("authors", ""),
+                    "journal": p.get("journal", ""),
+                    "year": p.get("year", "")
+                } for pmid, p in included.items()
+            },
             "pct": 100
         })
 
@@ -175,7 +200,6 @@ def export_pdf():
                      as_attachment=True, download_name=filename)
 
 
-
 @app.route("/score", methods=["POST"])
 def score():
     data = request.get_json()
@@ -224,22 +248,26 @@ import json as _json
 from datetime import datetime
 SESSIONS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sessions.json")
 
+
 def load_sessions():
     try:
         return _json.load(open(SESSIONS_FILE))
     except:
         return []
 
+
 def save_session(entry):
     sessions = load_sessions()
     sessions.insert(0, entry)
     sessions = sessions[:20]
     _json.dump(sessions, open(SESSIONS_FILE, "w"), indent=2)
 
+
 @app.route("/sessions", methods=["GET"])
 def get_sessions():
     return jsonify({"sessions": load_sessions()})
 
+
 @app.route("/sessions/save", methods=["POST"])
 def save_session_route():
     data = request.get_json()
@@ -273,6 +301,7 @@ def extract_table():
         traceback.print_exc()
         return jsonify({"error": str(e)}), 500
 
+
 @app.route("/followup", methods=["POST"])
 def followup():
     data = request.get_json()
@@ -298,10 +327,11 @@ def followup():
             f"Papers:\n{corpus}\n\n"
             f"Follow-up Question: {question}"
         )
-        response = llm.invoke(prompt)
+        response = llm_invoke_with_retry(llm, prompt)
        return jsonify({"answer": response.content})
     except Exception as e:
         return jsonify({"error": str(e)}), 500
 
+
 if __name__ == "__main__":
-    app.run(debug=True, port=5000, threaded=True)
+    app.run(debug=True, port=5000, threaded=True)
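With the PRISMA stage added, the stream's event order becomes stage → papers → prisma → stage → synthesis → stage → done, with pct checkpoints at 35/50/65/88/100. A rough consumer sketch, assuming emit() writes standard SSE "data:" frames and the generator is served from a /stream endpoint with a q parameter (neither detail is visible in this diff):

import json
import requests

# Hypothetical endpoint and parameter name; only the generator body
# appears in this commit.
resp = requests.get("http://localhost:5000/stream",
                    params={"q": "statins for primary prevention"},
                    stream=True)

for raw in resp.iter_lines(decode_unicode=True):
    # SSE frames look like "data: {...}", separated by blank lines.
    if not raw or not raw.startswith("data:"):
        continue
    event = json.loads(raw[len("data:"):].strip())
    if "excluded_count" in event:   # the new "prisma" event
        print(f"PRISMA: {event['included_count']} in, {event['excluded_count']} out")
    if event.get("pct") == 100:     # the final "done" event
        print("papers after screening:", event["paper_count"])
        break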
sessions.json CHANGED
The diff for this file is too large to render. See raw diff
 
templates/index.html CHANGED
@@ -1058,9 +1058,9 @@ async function submitQuery() {
     setStage(5);
     es.close();
     renderResults(data);
-    scoreResults(data.synthesis);
-    setTimeout(() => runPredictiveModel(lastQuery, data.synthesis), 4000);
-    setTimeout(() => buildTable(lastQuery, data.synthesis, data.papers || {}), 6000);
+    setTimeout(() => scoreResults(data.synthesis), 2000);
+    setTimeout(() => runPredictiveModel(lastQuery, data.synthesis), 20000);
+    setTimeout(() => buildTable(lastQuery, data.synthesis, data.papers || {}), 38000);
     saveSession(data, q);
     btn.disabled = false;
   });
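Design note: the new delays space the three post-pipeline LLM calls roughly 18 s apart (2 s, 20 s, 38 s after the stream closes) instead of firing them within 6 s of the synthesis call. Together with the server-side time.sleep(12) before synthesis and the llm_invoke_with_retry fallback, this staggering presumably keeps the burst of Groq requests under the per-minute rate limit rather than relying on retries alone.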