"""FetchMerck AI Demo (lightweight, publishable).

A small public demonstration of a Retrieval-Augmented Generation (RAG)
pipeline for clinical decision support. If a prebuilt MedlinePlus index is
present at data/corpus.jsonl + data/embeddings.npy it is used; otherwise the
Space falls back to a tiny in-memory sample corpus so the demo always works.

Generation is delegated to a hosted model via the Hugging Face Inference API.

IMPORTANT: This is an educational prototype only. It is NOT a medical device
and must not be used for diagnosis, treatment, or any clinical decision.
"""

import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer

DISCLAIMER = (
    "⚠️ For demonstration and educational use only. "
    "Not a medical device. Not for diagnosis or treatment. "
    "Always consult a licensed clinician."
)

ATTRIBUTION_MEDLINE = (
    "Source content adapted from MedlinePlus (U.S. National Library of Medicine, public domain)."
)

DATA_DIR = Path("data")
CORPUS_PATH = DATA_DIR / "corpus.jsonl"
EMBED_PATH = DATA_DIR / "embeddings.npy"

# ---- Fallback in-memory sample corpus -------------------------------------
SAMPLE_CORPUS = [
    {
        "id": "sample::hypertension",
        "topic": "Hypertension",
        "section": "Hypertension",
        "url": "",
        "text": (
            "Hypertension is generally defined as a sustained systolic blood "
            "pressure of 130 mm Hg or higher or diastolic of 80 mm Hg or higher. "
            "First-line lifestyle measures include reduced sodium intake, weight "
            "loss, regular aerobic exercise, and limiting alcohol."
        ),
    },
    {
        "id": "sample::t2dm",
        "topic": "Type 2 Diabetes",
        "section": "Type 2 Diabetes",
        "url": "",
        "text": (
            "Type 2 diabetes is characterized by insulin resistance and relative "
            "insulin deficiency. Initial management commonly includes metformin "
            "alongside dietary changes and increased physical activity. Routine "
            "monitoring of HbA1c is used to guide therapy."
        ),
    },
    {
        "id": "sample::cap",
        "topic": "Community-Acquired Pneumonia",
        "section": "Community-Acquired Pneumonia",
        "url": "",
        "text": (
            "Community-acquired pneumonia in otherwise healthy outpatients is "
            "often treated empirically with a macrolide or doxycycline. Severity "
            "scores such as CURB-65 help decide between outpatient and inpatient "
            "care."
        ),
    },
    {
        "id": "sample::asthma",
        "topic": "Acute Asthma Exacerbation",
        "section": "Acute Asthma Exacerbation",
        "url": "",
        "text": (
            "Acute asthma exacerbations are typically managed with inhaled "
            "short-acting beta-agonists, systemic corticosteroids, and oxygen "
            "when saturation is low. Patients with poor response or worsening "
            "work of breathing require escalation of care."
        ),
    },
    {
        "id": "sample::ida",
        "topic": "Iron Deficiency Anemia",
        "section": "Iron Deficiency Anemia",
        "url": "",
        "text": (
            "Iron deficiency anemia commonly presents with fatigue, pallor, and "
            "a microcytic hypochromic blood picture. Workup should identify the "
            "underlying source of iron loss, especially gastrointestinal in "
            "adults."
        ),
    },
]

# ---- Embedder -------------------------------------------------------------
EMBED_MODEL = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
embedder = SentenceTransformer(EMBED_MODEL)


def load_corpus():
    """Load the retrieval corpus and its embedding matrix.

    Returns:
        tuple: ``(records, embeddings, source_label)`` where ``records`` is a
        list of chunk dicts, ``embeddings`` is a float32 array of L2-normalized
        row vectors aligned with ``records``, and ``source_label`` is
        ``"medlineplus"`` for the prebuilt index or ``"sample"`` for the
        in-memory fallback.

    Raises:
        RuntimeError: if the prebuilt corpus and embedding files disagree in
            length (a corrupt or partially updated index).
    """
    if CORPUS_PATH.exists() and EMBED_PATH.exists():
        records = []
        with CORPUS_PATH.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines in the JSONL file
                records.append(json.loads(line))
        embs = np.load(EMBED_PATH).astype("float32")
        if len(records) != embs.shape[0]:
            raise RuntimeError(
                f"corpus/embedding length mismatch: {len(records)} vs {embs.shape[0]}"
            )
        # Re-normalize defensively: cosine retrieval below assumes unit rows,
        # and a prebuilt index may not guarantee that. Zero rows are left
        # untouched (norm forced to 1) to avoid division by zero.
        norms = np.linalg.norm(embs, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        embs = embs / norms
        return records, embs, "medlineplus"

    # Fallback: embed the tiny built-in sample corpus at startup.
    texts = [d["text"] for d in SAMPLE_CORPUS]
    embs = embedder.encode(texts, normalize_embeddings=True).astype("float32")
    return list(SAMPLE_CORPUS), embs, "sample"


CORPUS_RECORDS, CORPUS_EMBS, CORPUS_SOURCE = load_corpus()
print(f"[startup] loaded {len(CORPUS_RECORDS)} chunks from {CORPUS_SOURCE} corpus")


def retrieve(query: str, k: int = 4):
    """Return the top-``k`` corpus chunks most similar to ``query``.

    Similarity is the dot product of unit vectors (cosine similarity, since
    both the query and corpus embeddings are normalized).

    Returns:
        list[tuple[dict, float]]: ``(record, score)`` pairs, best first.
    """
    q = embedder.encode([query], normalize_embeddings=True)[0].astype("float32")
    sims = CORPUS_EMBS @ q
    idx = np.argsort(-sims)[:k]
    return [(CORPUS_RECORDS[i], float(sims[i])) for i in idx]


# ---- Hosted generation via HF Inference Providers ------------------------
GEN_MODEL = os.environ.get("GEN_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)

SYSTEM_PROMPT = (
    "You are a clinical decision support assistant for educational use only. "
    "Answer ONLY using the provided context. If the context is insufficient, "
    "say so explicitly. Always remind the user this is not medical advice."
)


def build_messages(message, history, context):
    """Build the chat-completion message list: system + history + RAG prompt.

    Accepts both Gradio history formats:
      * "messages" format — a list of ``{"role": ..., "content": ...}`` dicts,
        which are forwarded as-is;
      * legacy "tuples" format — a list of ``(user_text, assistant_text)``
        pairs, which are expanded into user/assistant messages. (Previously
        these were silently dropped, losing all conversation history.)
    """
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            msgs.append(turn)
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, bot_text = turn
            if user_text:
                msgs.append({"role": "user", "content": str(user_text)})
            if bot_text:
                msgs.append({"role": "assistant", "content": str(bot_text)})
    user_block = f"Context:\n{context}\n\nQuestion: {message}"
    msgs.append({"role": "user", "content": user_block})
    return msgs


def respond(message, history):
    """Chat handler: retrieve context, generate an answer, append sources.

    On any generation-backend failure the raw retrieved context is returned
    instead of crashing the UI (deliberate best-effort fallback, so the broad
    ``except`` is intentional). The reply always ends with a sources line and
    the legal disclaimer, plus MedlinePlus attribution when that corpus is in
    use.
    """
    hits = retrieve(message, k=4)

    # Assemble the grounding context: "[Section url] body" per chunk.
    pieces = []
    for doc, _ in hits:
        label = doc.get("section") or doc.get("topic") or "?"
        url = doc.get("url") or ""
        head = "[" + label + (" " + url if url else "") + "]"
        pieces.append(f"{head} {doc.get('text', '')}")
    context = "\n\n".join(pieces)

    messages = build_messages(message, history, context)
    try:
        out = client.chat_completion(
            messages=messages,
            max_tokens=400,
            temperature=0.2,
        )
        answer = out.choices[0].message.content
    except Exception as e:  # boundary: degrade to raw context, never crash UI
        answer = f"(Generation backend error: {e})\n\nRetrieved context:\n{context}"

    # De-duplicated source labels, preserving retrieval order.
    seen = []
    for doc, _ in hits:
        label = doc.get("topic") or doc.get("section") or ""
        if label and label not in seen:
            seen.append(label)
    src_line = "Sources: " + (", ".join(seen) if seen else "(none)")

    footer = f"{src_line}\n{DISCLAIMER}"
    if CORPUS_SOURCE == "medlineplus":
        footer = f"{ATTRIBUTION_MEDLINE}\n{footer}"
    return f"{answer}\n\n---\n{footer}"


demo = gr.ChatInterface(
    fn=respond,
    title="FetchMerck AI — Demo",
    description=(
        "Lightweight RAG demo for clinical decision support. "
        "Uses a public MedlinePlus-derived corpus when present, otherwise a small sample corpus. "
        + DISCLAIMER
    ),
    examples=[
        "What is first-line management of community-acquired pneumonia?",
        "How is iron deficiency anemia evaluated in adults?",
        "Initial steps for an acute asthma exacerbation?",
    ],
)

if __name__ == "__main__":
    demo.launch()