# Load MedlinePlus prebuilt corpus when present, fall back to sample;
# add MedlinePlus attribution. (commit 5e5f986)
import json
import os
from pathlib import Path
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
# -----------------------------------------------------------------------------
# FetchMerck AI Demo (lightweight, publishable)
#
# A small public demonstration of a Retrieval-Augmented Generation (RAG)
# pipeline for clinical decision support. If a prebuilt MedlinePlus index is
# present at data/corpus.jsonl + data/embeddings.npy it is used; otherwise the
# Space falls back to a tiny in-memory sample corpus so the demo always works.
#
# Generation is delegated to a hosted model via the Hugging Face Inference API.
#
# IMPORTANT: This is an educational prototype only. It is NOT a medical device
# and must not be used for diagnosis, treatment, or any clinical decision.
# -----------------------------------------------------------------------------
# Safety disclaimer shown in the UI description and appended to every answer.
DISCLAIMER = (
    "⚠️ For demonstration and educational use only. "
    "Not a medical device. Not for diagnosis or treatment. "
    "Always consult a licensed clinician."
)
# Attribution line required when serving MedlinePlus-derived content;
# only appended when the prebuilt corpus is actually in use (see respond()).
ATTRIBUTION_MEDLINE = (
    "Source content adapted from MedlinePlus (U.S. National Library of Medicine, public domain)."
)
# Prebuilt index layout: one JSON record per line plus a parallel matrix of
# embeddings (row i of embeddings.npy corresponds to line i of corpus.jsonl).
DATA_DIR = Path("data")
CORPUS_PATH = DATA_DIR / "corpus.jsonl"
EMBED_PATH = DATA_DIR / "embeddings.npy"
# ---- Fallback in-memory sample corpus -------------------------------------
# Used only when the prebuilt MedlinePlus index is absent, so the demo always
# has something to retrieve against. Each record mirrors the prebuilt corpus
# schema: id, topic, section, url (empty for samples), and the chunk text.
SAMPLE_CORPUS = [
    {
        "id": "sample::hypertension",
        "topic": "Hypertension",
        "section": "Hypertension",
        "url": "",
        "text": (
            "Hypertension is generally defined as a sustained systolic blood "
            "pressure of 130 mm Hg or higher or diastolic of 80 mm Hg or higher. "
            "First-line lifestyle measures include reduced sodium intake, weight "
            "loss, regular aerobic exercise, and limiting alcohol."
        ),
    },
    {
        "id": "sample::t2dm",
        "topic": "Type 2 Diabetes",
        "section": "Type 2 Diabetes",
        "url": "",
        "text": (
            "Type 2 diabetes is characterized by insulin resistance and relative "
            "insulin deficiency. Initial management commonly includes metformin "
            "alongside dietary changes and increased physical activity. Routine "
            "monitoring of HbA1c is used to guide therapy."
        ),
    },
    {
        "id": "sample::cap",
        "topic": "Community-Acquired Pneumonia",
        "section": "Community-Acquired Pneumonia",
        "url": "",
        "text": (
            "Community-acquired pneumonia in otherwise healthy outpatients is "
            "often treated empirically with a macrolide or doxycycline. Severity "
            "scores such as CURB-65 help decide between outpatient and inpatient "
            "care."
        ),
    },
    {
        "id": "sample::asthma",
        "topic": "Acute Asthma Exacerbation",
        "section": "Acute Asthma Exacerbation",
        "url": "",
        "text": (
            "Acute asthma exacerbations are typically managed with inhaled "
            "short-acting beta-agonists, systemic corticosteroids, and oxygen "
            "when saturation is low. Patients with poor response or worsening "
            "work of breathing require escalation of care."
        ),
    },
    {
        "id": "sample::ida",
        "topic": "Iron Deficiency Anemia",
        "section": "Iron Deficiency Anemia",
        "url": "",
        "text": (
            "Iron deficiency anemia commonly presents with fatigue, pallor, and "
            "a microcytic hypochromic blood picture. Workup should identify the "
            "underlying source of iron loss, especially gastrointestinal in "
            "adults."
        ),
    },
]
# ---- Embedder -------------------------------------------------------------
# Sentence-embedding model used for BOTH query encoding and (fallback-path)
# corpus encoding. NOTE(review): when the prebuilt embeddings.npy was created
# with a different model than EMBED_MODEL, query/corpus vectors will not be
# comparable — presumably the prebuild script uses the same env var; confirm.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
embedder = SentenceTransformer(EMBED_MODEL)  # loaded once at import time
def load_corpus():
    """Load the retrieval corpus and its embedding matrix.

    Prefers the prebuilt MedlinePlus index (data/corpus.jsonl plus
    data/embeddings.npy). If either file is missing, falls back to the
    small built-in sample corpus, embedding it on the fly so the demo
    always works.

    Returns:
        tuple: ``(records, embeddings, source_label)`` where *records*
        is a list of chunk dicts, *embeddings* is a float32 matrix with
        one L2-normalized row per record, and *source_label* is either
        ``"medlineplus"`` or ``"sample"``.

    Raises:
        RuntimeError: if the prebuilt corpus and embedding matrix have
        mismatched lengths.
    """
    if not (CORPUS_PATH.exists() and EMBED_PATH.exists()):
        # Fallback: embed the in-memory sample corpus at startup.
        sample_texts = [doc["text"] for doc in SAMPLE_CORPUS]
        sample_embs = embedder.encode(sample_texts, normalize_embeddings=True).astype("float32")
        return list(SAMPLE_CORPUS), sample_embs, "sample"

    # One JSON record per non-blank line of the prebuilt corpus file.
    records = [
        json.loads(raw)
        for raw in CORPUS_PATH.read_text(encoding="utf-8").splitlines()
        if raw.strip()
    ]
    embs = np.load(EMBED_PATH).astype("float32")
    if embs.shape[0] != len(records):
        raise RuntimeError(
            f"corpus/embedding length mismatch: {len(records)} vs {embs.shape[0]}"
        )
    # Defensive re-normalization so dot products are true cosine sims;
    # zero rows are left untouched rather than dividing by zero.
    row_norms = np.linalg.norm(embs, axis=1, keepdims=True)
    row_norms[row_norms == 0] = 1.0
    return records, embs / row_norms, "medlineplus"
# Build the retrieval index once at import time; reused by every request.
CORPUS_RECORDS, CORPUS_EMBS, CORPUS_SOURCE = load_corpus()
print(f"[startup] loaded {len(CORPUS_RECORDS)} chunks from {CORPUS_SOURCE} corpus")
def retrieve(query: str, k: int = 4):
    """Return the *k* corpus chunks most similar to *query*.

    Similarity is the dot product of L2-normalized embeddings, i.e.
    cosine similarity. Results are ``(record, score)`` pairs, best
    first; fewer than *k* pairs are returned if the corpus is smaller.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True)[0].astype("float32")
    scores = CORPUS_EMBS @ query_vec
    results = []
    # Negating the scores makes argsort yield a best-first ordering.
    for i in np.argsort(-scores)[:k]:
        results.append((CORPUS_RECORDS[i], float(scores[i])))
    return results
# ---- Hosted generation via HF Inference Providers ------------------------
# Generator model is configurable via env var; HF_TOKEN may be unset, in
# which case the InferenceClient falls back to anonymous (rate-limited) access.
GEN_MODEL = os.environ.get("GEN_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)
# System prompt sent as the first message of every generation request;
# constrains the model to the retrieved context and forces the advice caveat.
SYSTEM_PROMPT = (
    "You are a clinical decision support assistant for educational use only. "
    "Answer ONLY using the provided context. If the context is insufficient, "
    "say so explicitly. Always remind the user this is not medical advice."
)
def build_messages(message, history, context):
    """Assemble the chat-completion message list for the generator.

    Parameters
    ----------
    message : str
        The user's current question.
    history : list | None
        Prior chat turns from Gradio. Depending on the Gradio version and
        ChatInterface configuration, turns arrive either as OpenAI-style
        dicts (``{"role": ..., "content": ...}``) or as legacy
        ``(user_text, assistant_text)`` pairs. Both are accepted here.
        BUG FIX: the previous implementation appended only dict turns, so
        pair-format history was silently dropped and the model never saw
        earlier conversation context.
    context : str
        Retrieved document chunks used to ground the answer.

    Returns
    -------
    list[dict]
        System prompt, normalized history, then the current user turn
        with the retrieval context prepended.
    """
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            # Already in message format — pass through unchanged.
            msgs.append(turn)
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            # Legacy Gradio pair format: (user_text, assistant_text);
            # either side may be None/empty on a partial turn.
            user_text, bot_text = turn
            if user_text:
                msgs.append({"role": "user", "content": str(user_text)})
            if bot_text:
                msgs.append({"role": "assistant", "content": str(bot_text)})
    user_block = "Context:\n" + context + "\n\nQuestion: " + str(message)
    msgs.append({"role": "user", "content": user_block})
    return msgs
def respond(message, history):
    """Gradio chat handler: retrieve, generate, and append sources/disclaimer.

    Retrieves the top matching corpus chunks, asks the hosted model to
    answer from that context, and returns the answer followed by a
    footer with de-duplicated source labels, the MedlinePlus attribution
    (when the prebuilt corpus is in use), and the safety disclaimer. On
    a generation-backend failure the raw retrieved context is returned
    instead of an answer so the demo degrades gracefully.
    """
    hits = retrieve(message, k=4)

    # Render each retrieved chunk as "[section url] text".
    snippets = []
    for doc, _score in hits:
        heading = doc.get("section") or doc.get("topic") or "?"
        link = doc.get("url") or ""
        tag = f"[{heading} {link}]" if link else f"[{heading}]"
        snippets.append(f"{tag} {doc.get('text', '')}")
    context = "\n\n".join(snippets)

    messages = build_messages(message, history, context)
    try:
        completion = client.chat_completion(
            messages=messages,
            max_tokens=400,
            temperature=0.2,
        )
        answer = completion.choices[0].message.content
    except Exception as err:
        # Best-effort fallback: surface the error plus the raw context.
        answer = f"(Generation backend error: {err})\n\nRetrieved context:\n{context}"

    # Source labels, de-duplicated while preserving retrieval order.
    labels = [doc.get("topic") or doc.get("section") or "" for doc, _score in hits]
    unique = list(dict.fromkeys(label for label in labels if label))
    src_line = "Sources: " + (", ".join(unique) if unique else "(none)")

    footer = f"{src_line}\n{DISCLAIMER}"
    if CORPUS_SOURCE == "medlineplus":
        footer = f"{ATTRIBUTION_MEDLINE}\n{footer}"
    return f"{answer}\n\n---\n{footer}"
# Gradio chat UI wired to the RAG handler above.
# NOTE(review): ChatInterface is constructed without type="messages", so on
# older Gradio versions history reaches respond() as (user, assistant) pairs
# rather than message dicts — confirm the installed Gradio version's default.
demo = gr.ChatInterface(
    fn=respond,
    title="FetchMerck AI — Demo",
    description=(
        "Lightweight RAG demo for clinical decision support. "
        "Uses a public MedlinePlus-derived corpus when present, otherwise a small sample corpus. "
        + DISCLAIMER
    ),
    examples=[
        "What is first-line management of community-acquired pneumonia?",
        "How is iron deficiency anemia evaluated in adults?",
        "Initial steps for an acute asthma exacerbation?",
    ],
)
# Launch only when run as a script; on HF Spaces the module-level `demo`
# object is also picked up automatically.
if __name__ == "__main__":
    demo.launch()