sriirohit3107 commited on
Commit
0224078
Β·
1 Parent(s): 68048d2

Initial Commit

Browse files
Files changed (2) hide show
  1. app.py +578 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import List, Dict, Any, Optional
4
+ from concurrent.futures import ThreadPoolExecutor
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
9
+ from Bio import Entrez
10
+ import traceback
11
+ import pandas as pd
12
+
13
+
14
+ # Cache last-created agent to avoid reloading the model on every call
15
+ _CACHED_AGENT_KEY = None
16
+ _CACHED_AGENT = None
17
+
18
+ # Also cache model/tokenizer per device to prevent repeated downloads
19
+ _MODEL_CACHE: Dict[str, Dict[str, Any]] = {}
20
+
21
+
22
+ MODEL_NAME = "hkust-nlp/WebExplorer-8B"
23
+
24
+
25
+ def _get_hf_components(device_str: str) -> Dict[str, Any]:
26
+ """Load and cache tokenizer/model for the requested device string."""
27
+ if device_str in _MODEL_CACHE:
28
+ return _MODEL_CACHE[device_str]
29
+
30
+ print(f"Loading model for device: {device_str}")
31
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
32
+
33
+ # Configure 4-bit quantization for much faster loading and inference (with safe fallback)
34
+ if torch.cuda.is_available():
35
+ try:
36
+ quantization_config = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type="nf4"
41
+ )
42
+ model = AutoModelForCausalLM.from_pretrained(
43
+ MODEL_NAME,
44
+ quantization_config=quantization_config,
45
+ device_map="auto",
46
+ trust_remote_code=True,
47
+ low_cpu_mem_usage=True,
48
+ )
49
+ except Exception as e:
50
+ print(f"4-bit load failed, falling back to standard half precision: {e}")
51
+ model = AutoModelForCausalLM.from_pretrained(
52
+ MODEL_NAME,
53
+ device_map="auto",
54
+ torch_dtype=torch.float16,
55
+ trust_remote_code=True,
56
+ low_cpu_mem_usage=True,
57
+ )
58
+ else:
59
+ # CPU fallback (slower)
60
+ model = AutoModelForCausalLM.from_pretrained(
61
+ MODEL_NAME,
62
+ device_map="auto",
63
+ torch_dtype=torch.float32,
64
+ low_cpu_mem_usage=True,
65
+ )
66
+
67
+ # Set padding token if not set
68
+ if tokenizer.pad_token is None:
69
+ tokenizer.pad_token = tokenizer.eos_token
70
+ model.config.pad_token_id = tokenizer.eos_token_id
71
+
72
+ print(f"Model loaded successfully on {device_str}")
73
+ _MODEL_CACHE[device_str] = {"tokenizer": tokenizer, "model": model}
74
+ return _MODEL_CACHE[device_str]
75
+
76
+
77
+ class LocalWebExplorerAgent:
78
+ """Optimized medical research agent with PubMed integration."""
79
+
80
+ def __init__(self, search_targets: List[str], use_cpu: bool):
81
+ self.search_targets = search_targets
82
+ self.device_str = "cpu" if use_cpu else ("cuda" if torch.cuda.is_available() else "cpu")
83
+
84
+ # Configure Entrez from environment variables if present
85
+ Entrez.email = os.getenv("ENTREZ_EMAIL", "harshini.kalvakuntla@gmail.com")
86
+ Entrez.api_key = os.getenv("ENTREZ_API_KEY","e87e8f21aeaa01cdd5690c52e8a4f5336008")
87
+
88
+ comps = _get_hf_components(self.device_str)
89
+ self.tokenizer = comps["tokenizer"]
90
+ self.model = comps["model"]
91
+
92
+ # Cache for search results to avoid redundant API calls
93
+ self.search_cache: Dict[str, List[Dict[str, str]]] = {}
94
+
95
+ def _needs_search(self, query: str) -> bool:
96
+ """Determine if external search is needed."""
97
+ lowered = query.lower()
98
+ trigger_terms = [
99
+ "treatment", "survival", "trial", "latest", "guideline",
100
+ "therapy", "diagnosis", "prognosis", "rate", "statistic",
101
+ "study", "research", "clinical", "evidence"
102
+ ]
103
+ return any(term in lowered for term in trigger_terms)
104
+
105
+ def _extract_diagnosis(self, query: str) -> str:
106
+ """Extract medical condition from query."""
107
+ query_lower = query.lower()
108
+
109
+ # Common conditions mapping
110
+ conditions = {
111
+ "lung": "lung cancer",
112
+ "pancreatic": "pancreatic cancer",
113
+ "breast": "breast cancer",
114
+ "colon": "colorectal cancer",
115
+ "prostate": "prostate cancer",
116
+ "melanoma": "melanoma",
117
+ "diabetes": "diabetes mellitus",
118
+ "heart failure": "heart failure",
119
+ "hypertension": "hypertension",
120
+ }
121
+
122
+ for key, value in conditions.items():
123
+ if key in query_lower:
124
+ return value
125
+
126
+ return "general medical condition"
127
+
128
+ def _pubmed_search(self, diagnosis: str) -> List[Dict[str, str]]:
129
+ """Search PubMed with caching."""
130
+ # Check cache first
131
+ if diagnosis in self.search_cache:
132
+ return self.search_cache[diagnosis]
133
+
134
+ if not Entrez.email or Entrez.email == "user@example.com":
135
+ # Return empty if no valid email configured
136
+ return []
137
+
138
+ try:
139
+ query = f"{diagnosis} treatment guidelines[Title/Abstract] OR {diagnosis} clinical practice[Title/Abstract]"
140
+ handle = Entrez.esearch(db="pubmed", term=query, retmax=3, sort="relevance")
141
+ record = Entrez.read(handle)
142
+ handle.close()
143
+
144
+ ids = record.get("IdList", [])
145
+ results: List[Dict[str, str]] = []
146
+
147
+ if ids:
148
+ # Fetch summaries in batch
149
+ fetch = Entrez.esummary(db="pubmed", id=",".join(ids), retmode="xml")
150
+ summary_list = Entrez.read(fetch)
151
+ fetch.close()
152
+
153
+ for summary in summary_list:
154
+ pmid = summary.get("Id", "")
155
+ title = summary.get("Title", "No title")
156
+ results.append({
157
+ "pmid": str(pmid),
158
+ "title": title,
159
+ "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/"
160
+ })
161
+
162
+ # Cache results
163
+ self.search_cache[diagnosis] = results
164
+ return results
165
+
166
+ except Exception as e:
167
+ print(f"PubMed search error: {e}")
168
+ return []
169
+
170
+ def _fetch_abstracts(self, pmids: List[str]) -> str:
171
+ """Fetch abstracts in parallel for speed."""
172
+ if not Entrez.email or not pmids:
173
+ return ""
174
+
175
+ def fetch_single(pmid: str) -> str:
176
+ try:
177
+ fetch = Entrez.efetch(db="pubmed", id=pmid, rettype="abstract", retmode="text")
178
+ content = fetch.read()
179
+ fetch.close()
180
+
181
+ if isinstance(content, bytes):
182
+ content = content.decode('utf-8', errors='ignore')
183
+ return content
184
+ except Exception as e:
185
+ print(f"Error fetching abstract for PMID {pmid}: {e}")
186
+ return ""
187
+
188
+ # Use ThreadPoolExecutor for parallel fetching
189
+ with ThreadPoolExecutor(max_workers=3) as executor:
190
+ abstracts = list(executor.map(fetch_single, pmids))
191
+
192
+ return "\n\n".join([a for a in abstracts if a])
193
+
194
+ def _generate(self, prompt: str, max_new_tokens: int = 200) -> str:
195
+ """Optimized generation with proper settings."""
196
+ inputs = self.tokenizer(
197
+ prompt,
198
+ return_tensors="pt",
199
+ truncation=True,
200
+ max_length=1024 # Limit input length for speed
201
+ ).to(self.model.device)
202
+
203
+ with torch.inference_mode(): # Faster than torch.no_grad()
204
+ outputs = self.model.generate(
205
+ **inputs,
206
+ max_new_tokens=max_new_tokens,
207
+ do_sample=False, # Greedy decoding is fastest
208
+ num_beams=1,
209
+ pad_token_id=self.tokenizer.pad_token_id,
210
+ eos_token_id=self.tokenizer.eos_token_id,
211
+ use_cache=True, # KV cache for speed
212
+ )
213
+
214
+ # Decode only the generated tokens
215
+ generated_ids = outputs[0][inputs.input_ids.shape[1]:]
216
+ return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
217
+
218
+ def execute_query(self, query: str, max_turns: int = 3) -> Dict[str, Any]:
219
+ """Execute a single query with optimized flow."""
220
+ turns: List[Dict[str, Any]] = []
221
+ timestamp = int(time.time())
222
+
223
+ # Extract diagnosis
224
+ diagnosis = self._extract_diagnosis(query)
225
+
226
+ # Turn 1: decision
227
+ needs_search = self._needs_search(query)
228
+ turns.append({
229
+ "turn": 1,
230
+ "action_decision": "search" if needs_search else "reason",
231
+ "tool_calls": [],
232
+ })
233
+
234
+ retrieved_docs: List[Dict[str, str]] = []
235
+ abstracts = ""
236
+
237
+ # Turn 2: search if needed
238
+ if needs_search and len(turns) < max_turns:
239
+ retrieved_docs = self._pubmed_search(diagnosis)
240
+ turns.append({
241
+ "turn": 2,
242
+ "action_decision": "search",
243
+ "tool_calls": [{
244
+ "tool": "pubmed.search",
245
+ "args": {"diagnosis": diagnosis},
246
+ "results": [f"PMID {d['pmid']}: {d['title']}" for d in retrieved_docs],
247
+ }],
248
+ })
249
+
250
+ # Fetch abstracts if we have PMIDs
251
+ if retrieved_docs:
252
+ pmids = [d["pmid"] for d in retrieved_docs]
253
+ abstracts = self._fetch_abstracts(pmids)
254
+
255
+ # Turn 3: Generate answer
256
+ prompt = self._build_prompt(query, diagnosis, abstracts)
257
+ answer_text = self._generate(prompt, max_new_tokens=200)
258
+
259
+ turns.append({
260
+ "turn": len(turns) + 1,
261
+ "action_decision": "reason",
262
+ "tool_calls": [],
263
+ "response": answer_text[:100] + "..."
264
+ })
265
+
266
+ # Add disclaimer and sources
267
+ answer_text = self._format_answer(answer_text, query, retrieved_docs)
268
+
269
+ return {
270
+ "model_loaded": True,
271
+ "final_answer": answer_text,
272
+ "turns": turns,
273
+ "total_turns": len(turns),
274
+ "timestamp": timestamp,
275
+ }
276
+
277
+ def _build_prompt(self, query: str, diagnosis: str, abstracts: str) -> str:
278
+ """Build optimized prompt."""
279
+ if abstracts:
280
+ return (
281
+ f"Answer this medical question based on the research below.\n\n"
282
+ f"Question: {query}\n\n"
283
+ f"Research on {diagnosis}:\n{abstracts[:1500]}\n\n" # Limit context
284
+ f"Provide a clear, concise summary of current treatments and outcomes."
285
+ )
286
+ else:
287
+ return (
288
+ f"Answer this medical question concisely and accurately.\n\n"
289
+ f"Question: {query}\n\n"
290
+ f"Provide evidence-based information in plain language."
291
+ )
292
+
293
+ def _format_answer(self, answer: str, query: str, docs: List[Dict[str, str]]) -> str:
294
+ """Format answer with disclaimer and sources."""
295
+ # Add medical disclaimer
296
+ medical_terms = ["cancer", "disease", "diabetes", "treatment", "diagnosis", "therapy"]
297
+ if any(term in query.lower() for term in medical_terms):
298
+ answer += "\n\n**Disclaimer:** This is educational information only. Always consult a healthcare professional for medical advice."
299
+
300
+ # Add sources
301
+ if docs:
302
+ answer += "\n\n**Sources:**\n" + "\n".join(
303
+ f"- [{d['title']}]({d['url']})" for d in docs
304
+ )
305
+
306
+ return answer
307
+
308
+ def execute_batch(self, queries: List[str], max_turns: int = 3, progress_callback=None) -> List[Dict[str, Any]]:
309
+ """Process multiple queries with progress tracking."""
310
+ results = []
311
+ total = len(queries)
312
+
313
+ for idx, query in enumerate(queries):
314
+ if progress_callback:
315
+ progress_callback((idx + 1) / total, desc=f"Processing query {idx + 1}/{total}")
316
+
317
+ try:
318
+ result = self.execute_query(query, max_turns=max_turns)
319
+ results.append(result)
320
+ except Exception as e:
321
+ print(f"Error processing query '{query}': {e}")
322
+ results.append({
323
+ "model_loaded": False,
324
+ "final_answer": f"Error: {str(e)}",
325
+ "turns": [],
326
+ "total_turns": 0,
327
+ "timestamp": int(time.time()),
328
+ "error": str(e)
329
+ })
330
+
331
+ return results
332
+
333
+
334
+ DEFAULT_TARGETS = [
335
+ 'nih.gov', 'cdc.gov', 'fda.gov', 'clinicaltrials.gov', 'medlineplus.gov',
336
+ 'who.int', 'cancerresearchuk.org', 'esmo.org', 'cancer.org', 'cancer.net',
337
+ 'mayoclinic.org', 'mdanderson.org', 'mskcc.org', 'dana-farber.org',
338
+ 'uptodate.com', 'ncbi.nlm.nih.gov', 'healthline.com',
339
+ ]
340
+
341
+
342
+ def get_agent(search_targets: List[str], use_cpu: bool) -> LocalWebExplorerAgent:
343
+ """Get or create cached agent."""
344
+ global _CACHED_AGENT_KEY, _CACHED_AGENT
345
+ key = (tuple(sorted(search_targets)), use_cpu)
346
+ if _CACHED_AGENT is not None and _CACHED_AGENT_KEY == key:
347
+ return _CACHED_AGENT
348
+ _CACHED_AGENT = LocalWebExplorerAgent(search_targets=search_targets, use_cpu=use_cpu)
349
+ _CACHED_AGENT_KEY = key
350
+ return _CACHED_AGENT
351
+
352
+
353
+ def run_query(query: str, domain_scope: str, device_choice: str, max_turns: int, fast_mode: bool, progress=gr.Progress()):
354
+ """Run a single query with progress tracking."""
355
+ if not query or not query.strip():
356
+ return "Please enter a query.", {}
357
+
358
+ progress(0, desc="Loading model...")
359
+ use_cpu = device_choice == "CPU"
360
+ targets = DEFAULT_TARGETS if domain_scope == "Medical (Trusted sources only)" else []
361
+
362
+ try:
363
+ agent = get_agent(targets, use_cpu=use_cpu)
364
+ progress(0.2, desc="Processing query...")
365
+
366
+ if fast_mode:
367
+ # Fast path: skip PubMed and generate a concise answer with fewer tokens
368
+ agent._needs_search = lambda q: False # bypass search
369
+ result = agent.execute_query(query.strip(), max_turns=1)
370
+ # Truncate final answer if too long
371
+ if result.get('final_answer'):
372
+ result['final_answer'] = result['final_answer'][:1200]
373
+ else:
374
+ result = agent.execute_query(query.strip(), max_turns=max_turns)
375
+
376
+ progress(1.0, desc="Complete!")
377
+ final_answer = result.get('final_answer', '')
378
+
379
+ mini_trace = {
380
+ 'model_loaded': result.get('model_loaded'),
381
+ 'turns': result.get('turns', []),
382
+ 'total_turns': result.get('total_turns'),
383
+ 'timestamp': result.get('timestamp'),
384
+ 'fast_mode': fast_mode,
385
+ }
386
+ return final_answer, mini_trace
387
+
388
+ except Exception as e:
389
+ tb = traceback.format_exc()
390
+ print("\n===== ERROR IN run_query =====\n", tb, "\n==============================\n")
391
+ return f"Error: {str(e)}", {"error": str(e), "traceback": tb}
392
+
393
+
394
+ def process_batch_file(file, domain_scope: str, device_choice: str, max_turns: int, progress=gr.Progress()):
395
+ """Process batch file with queries."""
396
+ if file is None:
397
+ return "Please upload a file.", None
398
+
399
+ progress(0, desc="Reading file...")
400
+
401
+ try:
402
+ # Read queries
403
+ if file.name.endswith('.csv'):
404
+ df = pd.read_csv(file.name)
405
+ if 'query' in df.columns:
406
+ queries = df['query'].tolist()
407
+ elif 'question' in df.columns:
408
+ queries = df['question'].tolist()
409
+ else:
410
+ queries = df.iloc[:, 0].tolist()
411
+ elif file.name.endswith('.txt'):
412
+ with open(file.name, 'r', encoding='utf-8') as f:
413
+ queries = [line.strip() for line in f if line.strip()]
414
+ else:
415
+ return "Please upload a CSV or TXT file.", None
416
+
417
+ if not queries:
418
+ return "No queries found in file.", None
419
+
420
+ progress(0.1, desc=f"Found {len(queries)} queries. Loading model...")
421
+
422
+ use_cpu = device_choice == "CPU"
423
+ targets = DEFAULT_TARGETS if domain_scope == "Medical (Trusted sources only)" else []
424
+ agent = get_agent(targets, use_cpu=use_cpu)
425
+
426
+ # Process batch
427
+ results = agent.execute_batch(
428
+ queries,
429
+ max_turns=max_turns,
430
+ progress_callback=lambda p, desc: progress(0.1 + p * 0.9, desc=desc)
431
+ )
432
+
433
+ # Create results dataframe
434
+ results_data = []
435
+ for query, result in zip(queries, results):
436
+ results_data.append({
437
+ 'Query': query,
438
+ 'Answer': result.get('final_answer', 'Error'),
439
+ 'Total Turns': result.get('total_turns', 0),
440
+ 'Success': result.get('model_loaded', False),
441
+ })
442
+
443
+ results_df = pd.DataFrame(results_data)
444
+
445
+ # Save results
446
+ output_path = f"batch_results_{int(time.time())}.csv"
447
+ results_df.to_csv(output_path, index=False)
448
+
449
+ progress(1.0, desc="Complete!")
450
+
451
+ success_count = sum(r.get('model_loaded', False) for r in results)
452
+ summary = (
453
+ f"βœ… Processed {len(queries)} queries\n\n"
454
+ f"πŸ“Š Success rate: {success_count}/{len(results)}\n\n"
455
+ f"πŸ’Ύ Results saved to: `{output_path}`"
456
+ )
457
+
458
+ return summary, results_df
459
+
460
+ except Exception as e:
461
+ tb = traceback.format_exc()
462
+ print("\n===== ERROR IN process_batch_file =====\n", tb, "\n==============================\n")
463
+ return f"Error processing file: {e}", None
464
+
465
+
466
+ # Gradio Interface
467
+ with gr.Blocks(title="WebExplorer-8B Medical Research", theme=gr.themes.Soft()) as demo:
468
+ gr.Markdown("""
469
+ # πŸ”¬ WebExplorer-8B Medical Research Assistant
470
+ Ask medical questions or process multiple queries in batch. Powered by AI and PubMed research.
471
+ """)
472
+
473
+ with gr.Tabs():
474
+ with gr.Tab("πŸ’¬ Single Query"):
475
+ with gr.Row():
476
+ query = gr.Textbox(
477
+ label="Medical Question",
478
+ lines=3,
479
+ placeholder="e.g., What are the treatment options for Type 2 diabetes?",
480
+ scale=4
481
+ )
482
+
483
+ with gr.Row():
484
+ domain_scope = gr.Radio(
485
+ choices=["Medical (Trusted sources only)", "All sources"],
486
+ value="Medical (Trusted sources only)",
487
+ label="Source Scope",
488
+ scale=2
489
+ )
490
+ device = gr.Radio(
491
+ choices=["GPU", "CPU"],
492
+ value="GPU",
493
+ label="Device",
494
+ scale=1
495
+ )
496
+ max_turns = gr.Slider(
497
+ minimum=1, maximum=5, value=2, step=1,
498
+ label="Max Research Depth",
499
+ scale=1
500
+ )
501
+ fast_mode = gr.Checkbox(value=True, label="Fast mode (skip PubMed, shorter answer)")
502
+
503
+ submit = gr.Button("πŸ” Research", variant="primary", size="lg")
504
+
505
+ answer = gr.Markdown(label="Answer", height=300)
506
+ trace = gr.Json(label="Execution Trace", visible=False)
507
+
508
+ gr.Markdown("### πŸ“š Example Questions")
509
+ gr.Examples(
510
+ examples=[
511
+ ["What are the survival rates for stage IV pancreatic cancer?"],
512
+ ["How is Type 2 diabetes diagnosed and treated?"],
513
+ ["What are the latest immunotherapy options for melanoma?"],
514
+ ["What are the risk factors for colorectal cancer?"],
515
+ ],
516
+ inputs=[query],
517
+ )
518
+
519
+ submit.click(
520
+ run_query,
521
+ inputs=[query, domain_scope, device, max_turns, fast_mode],
522
+ outputs=[answer, trace]
523
+ )
524
+
525
+ with gr.Tab("πŸ“Š Batch Processing"):
526
+ gr.Markdown("""
527
+ ### Process Multiple Queries
528
+ Upload a **CSV** (with 'query' column) or **TXT** file (one query per line).
529
+ """)
530
+
531
+ batch_file = gr.File(
532
+ label="Upload File",
533
+ file_types=['.csv', '.txt'],
534
+ scale=2
535
+ )
536
+
537
+ with gr.Row():
538
+ batch_domain = gr.Radio(
539
+ choices=["Medical (Trusted sources only)", "All sources"],
540
+ value="Medical (Trusted sources only)",
541
+ label="Source Scope"
542
+ )
543
+ batch_device = gr.Radio(
544
+ choices=["GPU", "CPU"],
545
+ value="GPU",
546
+ label="Device"
547
+ )
548
+ batch_turns = gr.Slider(
549
+ minimum=1, maximum=5, value=2, step=1,
550
+ label="Max Research Depth"
551
+ )
552
+
553
+ batch_submit = gr.Button("πŸš€ Process Batch", variant="primary", size="lg")
554
+
555
+ batch_status = gr.Markdown(label="Status")
556
+ batch_results = gr.Dataframe(label="Results Preview", max_height=400)
557
+
558
+ batch_submit.click(
559
+ process_batch_file,
560
+ inputs=[batch_file, batch_domain, batch_device, batch_turns],
561
+ outputs=[batch_status, batch_results]
562
+ )
563
+
564
+ gr.Markdown("""
565
+ ---
566
+ **Note:** Configure `ENTREZ_EMAIL` environment variable for PubMed access.
567
+ GPU recommended for faster processing (2-5s vs 30-60s on CPU).
568
+ """)
569
+
570
+
571
+ if __name__ == "__main__":
572
+ port = int(os.environ.get("PORT", "7860"))
573
+ demo.launch(
574
+ server_name="0.0.0.0",
575
+ server_port=port,
576
+ show_api=False,
577
+ share=False
578
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ accelerate
5
+ bitsandbytes
6
+ biopython
7
+ pandas