# app.py — FetchMerck AI demo Space
# (Hugging Face Spaces page chrome removed. Last commit message:
#  "Load MedlinePlus prebuilt corpus when present, fall back to sample;
#   add MedlinePlus attribution".)
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
# -----------------------------------------------------------------------------
# FetchMerck AI Demo (lightweight, publishable)
#
# A small public demonstration of a Retrieval-Augmented Generation (RAG)
# pipeline for clinical decision support. If a prebuilt MedlinePlus index is
# present at data/corpus.jsonl + data/embeddings.npy it is used; otherwise the
# Space falls back to a tiny in-memory sample corpus so the demo always works.
#
# Generation is delegated to a hosted model via the Hugging Face Inference API.
#
# IMPORTANT: This is an educational prototype only. It is NOT a medical device
# and must not be used for diagnosis, treatment, or any clinical decision.
# -----------------------------------------------------------------------------
| DISCLAIMER = ( | |
| "⚠️ For demonstration and educational use only. " | |
| "Not a medical device. Not for diagnosis or treatment. " | |
| "Always consult a licensed clinician." | |
| ) | |
| ATTRIBUTION_MEDLINE = ( | |
| "Source content adapted from MedlinePlus (U.S. National Library of Medicine, public domain)." | |
| ) | |
| DATA_DIR = Path("data") | |
| CORPUS_PATH = DATA_DIR / "corpus.jsonl" | |
| EMBED_PATH = DATA_DIR / "embeddings.npy" | |
| # ---- Fallback in-memory sample corpus ------------------------------------- | |
| SAMPLE_CORPUS = [ | |
| { | |
| "id": "sample::hypertension", | |
| "topic": "Hypertension", | |
| "section": "Hypertension", | |
| "url": "", | |
| "text": ( | |
| "Hypertension is generally defined as a sustained systolic blood " | |
| "pressure of 130 mm Hg or higher or diastolic of 80 mm Hg or higher. " | |
| "First-line lifestyle measures include reduced sodium intake, weight " | |
| "loss, regular aerobic exercise, and limiting alcohol." | |
| ), | |
| }, | |
| { | |
| "id": "sample::t2dm", | |
| "topic": "Type 2 Diabetes", | |
| "section": "Type 2 Diabetes", | |
| "url": "", | |
| "text": ( | |
| "Type 2 diabetes is characterized by insulin resistance and relative " | |
| "insulin deficiency. Initial management commonly includes metformin " | |
| "alongside dietary changes and increased physical activity. Routine " | |
| "monitoring of HbA1c is used to guide therapy." | |
| ), | |
| }, | |
| { | |
| "id": "sample::cap", | |
| "topic": "Community-Acquired Pneumonia", | |
| "section": "Community-Acquired Pneumonia", | |
| "url": "", | |
| "text": ( | |
| "Community-acquired pneumonia in otherwise healthy outpatients is " | |
| "often treated empirically with a macrolide or doxycycline. Severity " | |
| "scores such as CURB-65 help decide between outpatient and inpatient " | |
| "care." | |
| ), | |
| }, | |
| { | |
| "id": "sample::asthma", | |
| "topic": "Acute Asthma Exacerbation", | |
| "section": "Acute Asthma Exacerbation", | |
| "url": "", | |
| "text": ( | |
| "Acute asthma exacerbations are typically managed with inhaled " | |
| "short-acting beta-agonists, systemic corticosteroids, and oxygen " | |
| "when saturation is low. Patients with poor response or worsening " | |
| "work of breathing require escalation of care." | |
| ), | |
| }, | |
| { | |
| "id": "sample::ida", | |
| "topic": "Iron Deficiency Anemia", | |
| "section": "Iron Deficiency Anemia", | |
| "url": "", | |
| "text": ( | |
| "Iron deficiency anemia commonly presents with fatigue, pallor, and " | |
| "a microcytic hypochromic blood picture. Workup should identify the " | |
| "underlying source of iron loss, especially gastrointestinal in " | |
| "adults." | |
| ), | |
| }, | |
| ] | |
| # ---- Embedder ------------------------------------------------------------- | |
| EMBED_MODEL = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2") | |
| embedder = SentenceTransformer(EMBED_MODEL) | |
| def load_corpus(): | |
| """Return (records, embeddings, source_label).""" | |
| if CORPUS_PATH.exists() and EMBED_PATH.exists(): | |
| records = [] | |
| with CORPUS_PATH.open("r", encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| records.append(json.loads(line)) | |
| embs = np.load(EMBED_PATH).astype("float32") | |
| if len(records) != embs.shape[0]: | |
| raise RuntimeError( | |
| f"corpus/embedding length mismatch: {len(records)} vs {embs.shape[0]}" | |
| ) | |
| # Re-normalize defensively. | |
| norms = np.linalg.norm(embs, axis=1, keepdims=True) | |
| norms[norms == 0] = 1.0 | |
| embs = embs / norms | |
| return records, embs, "medlineplus" | |
| # Fallback | |
| texts = [d["text"] for d in SAMPLE_CORPUS] | |
| embs = embedder.encode(texts, normalize_embeddings=True).astype("float32") | |
| return list(SAMPLE_CORPUS), embs, "sample" | |
| CORPUS_RECORDS, CORPUS_EMBS, CORPUS_SOURCE = load_corpus() | |
| print(f"[startup] loaded {len(CORPUS_RECORDS)} chunks from {CORPUS_SOURCE} corpus") | |
| def retrieve(query: str, k: int = 4): | |
| q = embedder.encode([query], normalize_embeddings=True)[0].astype("float32") | |
| sims = CORPUS_EMBS @ q | |
| idx = np.argsort(-sims)[:k] | |
| return [(CORPUS_RECORDS[i], float(sims[i])) for i in idx] | |
| # ---- Hosted generation via HF Inference Providers ------------------------ | |
| GEN_MODEL = os.environ.get("GEN_MODEL", "meta-llama/Llama-3.1-8B-Instruct") | |
| HF_TOKEN = os.environ.get("HF_TOKEN") | |
| client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN) | |
| SYSTEM_PROMPT = ( | |
| "You are a clinical decision support assistant for educational use only. " | |
| "Answer ONLY using the provided context. If the context is insufficient, " | |
| "say so explicitly. Always remind the user this is not medical advice." | |
| ) | |
| def build_messages(message, history, context): | |
| msgs = [{"role": "system", "content": SYSTEM_PROMPT}] | |
| for turn in history or []: | |
| if isinstance(turn, dict): | |
| msgs.append(turn) | |
| user_block = "Context:" + chr(10) + context + chr(10) + chr(10) + "Question: " + str(message) | |
| msgs.append({"role": "user", "content": user_block}) | |
| return msgs | |
| def respond(message, history): | |
| hits = retrieve(message, k=4) | |
| pieces = [] | |
| for doc, _ in hits: | |
| label = doc.get("section") or doc.get("topic") or "?" | |
| url = doc.get("url") or "" | |
| body = doc.get("text", "") | |
| head = "[" + label + (" " + url if url else "") + "]" | |
| pieces.append(head + " " + body) | |
| context = (chr(10) + chr(10)).join(pieces) | |
| messages = build_messages(message, history, context) | |
| try: | |
| out = client.chat_completion( | |
| messages=messages, | |
| max_tokens=400, | |
| temperature=0.2, | |
| ) | |
| answer = out.choices[0].message.content | |
| except Exception as e: | |
| answer = ( | |
| "(Generation backend error: " + str(e) + ")" | |
| + chr(10) + chr(10) + "Retrieved context:" + chr(10) + context | |
| ) | |
| seen = [] | |
| for doc, _ in hits: | |
| label = doc.get("topic") or doc.get("section") or "" | |
| if label and label not in seen: | |
| seen.append(label) | |
| src_line = "Sources: " + (", ".join(seen) if seen else "(none)") | |
| footer = src_line + chr(10) + DISCLAIMER | |
| if CORPUS_SOURCE == "medlineplus": | |
| footer = ATTRIBUTION_MEDLINE + chr(10) + footer | |
| return answer + chr(10) + chr(10) + "---" + chr(10) + footer | |
| demo = gr.ChatInterface( | |
| fn=respond, | |
| title="FetchMerck AI — Demo", | |
| description=( | |
| "Lightweight RAG demo for clinical decision support. " | |
| "Uses a public MedlinePlus-derived corpus when present, otherwise a small sample corpus. " | |
| + DISCLAIMER | |
| ), | |
| examples=[ | |
| "What is first-line management of community-acquired pneumonia?", | |
| "How is iron deficiency anemia evaluated in adults?", | |
| "Initial steps for an acute asthma exacerbation?", | |
| ], | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |