File size: 8,079 Bytes
5e5f986
e63ed60
5e5f986
 
88d7726
008847a
 
 
 
 
 
 
5e5f986
 
 
 
 
008847a
 
 
 
 
88d7726
008847a
 
 
 
e63ed60
88d7726
5e5f986
 
 
 
 
 
 
 
 
008847a
 
5e5f986
 
008847a
5e5f986
008847a
 
 
 
 
 
 
 
5e5f986
 
008847a
5e5f986
008847a
 
 
 
 
 
 
 
5e5f986
 
008847a
5e5f986
008847a
 
 
 
 
 
 
 
5e5f986
 
008847a
5e5f986
008847a
 
 
 
 
 
 
 
5e5f986
 
008847a
5e5f986
008847a
 
 
 
 
 
 
 
88d7726
5e5f986
 
008847a
 
88d7726
5e5f986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
008847a
5e5f986
88d7726
008847a
5e5f986
008847a
5e5f986
008847a
 
 
 
 
 
88d7726
 
008847a
 
 
 
 
 
5e5f986
008847a
 
 
88d7726
e63ed60
5e5f986
 
 
 
 
 
 
 
 
008847a
 
 
 
 
 
e63ed60
008847a
 
 
5e5f986
 
008847a
5e5f986
 
 
 
 
 
 
 
 
 
008847a
e63ed60
 
 
008847a
 
 
5e5f986
008847a
 
 
 
 
 
 
e63ed60
88d7726
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

# -----------------------------------------------------------------------------
# FetchMerck AI Demo (lightweight, publishable)
#
# A small public demonstration of a Retrieval-Augmented Generation (RAG)
# pipeline for clinical decision support. If a prebuilt MedlinePlus index is
# present at data/corpus.jsonl + data/embeddings.npy it is used; otherwise the
# Space falls back to a tiny in-memory sample corpus so the demo always works.
#
# Generation is delegated to a hosted model via the Hugging Face Inference API.
#
# IMPORTANT: This is an educational prototype only. It is NOT a medical device
# and must not be used for diagnosis, treatment, or any clinical decision.
# -----------------------------------------------------------------------------

# Safety notice appended to every chat response footer and shown in the UI
# description — the demo must never be mistaken for medical advice.
DISCLAIMER = (
    "⚠️ For demonstration and educational use only. "
    "Not a medical device. Not for diagnosis or treatment. "
    "Always consult a licensed clinician."
)

# Attribution line prepended to the footer only when answers are built on the
# MedlinePlus-derived corpus (see respond()).
ATTRIBUTION_MEDLINE = (
    "Source content adapted from MedlinePlus (U.S. National Library of Medicine, public domain)."
)

# Optional prebuilt index location: corpus.jsonl holds one JSON record per
# line; embeddings.npy holds the row-aligned embedding matrix.
DATA_DIR = Path("data")
CORPUS_PATH = DATA_DIR / "corpus.jsonl"
EMBED_PATH = DATA_DIR / "embeddings.npy"

# ---- Fallback in-memory sample corpus -------------------------------------
# Used when the prebuilt MedlinePlus index is absent so the Space always has
# something to retrieve from. Each record mirrors the corpus.jsonl schema
# consumed by load_corpus()/retrieve(): "id", "topic", "section", "url"
# (empty for these samples), and the chunk "text".
SAMPLE_CORPUS = [
    {
        "id": "sample::hypertension",
        "topic": "Hypertension",
        "section": "Hypertension",
        "url": "",
        "text": (
            "Hypertension is generally defined as a sustained systolic blood "
            "pressure of 130 mm Hg or higher or diastolic of 80 mm Hg or higher. "
            "First-line lifestyle measures include reduced sodium intake, weight "
            "loss, regular aerobic exercise, and limiting alcohol."
        ),
    },
    {
        "id": "sample::t2dm",
        "topic": "Type 2 Diabetes",
        "section": "Type 2 Diabetes",
        "url": "",
        "text": (
            "Type 2 diabetes is characterized by insulin resistance and relative "
            "insulin deficiency. Initial management commonly includes metformin "
            "alongside dietary changes and increased physical activity. Routine "
            "monitoring of HbA1c is used to guide therapy."
        ),
    },
    {
        "id": "sample::cap",
        "topic": "Community-Acquired Pneumonia",
        "section": "Community-Acquired Pneumonia",
        "url": "",
        "text": (
            "Community-acquired pneumonia in otherwise healthy outpatients is "
            "often treated empirically with a macrolide or doxycycline. Severity "
            "scores such as CURB-65 help decide between outpatient and inpatient "
            "care."
        ),
    },
    {
        "id": "sample::asthma",
        "topic": "Acute Asthma Exacerbation",
        "section": "Acute Asthma Exacerbation",
        "url": "",
        "text": (
            "Acute asthma exacerbations are typically managed with inhaled "
            "short-acting beta-agonists, systemic corticosteroids, and oxygen "
            "when saturation is low. Patients with poor response or worsening "
            "work of breathing require escalation of care."
        ),
    },
    {
        "id": "sample::ida",
        "topic": "Iron Deficiency Anemia",
        "section": "Iron Deficiency Anemia",
        "url": "",
        "text": (
            "Iron deficiency anemia commonly presents with fatigue, pallor, and "
            "a microcytic hypochromic blood picture. Workup should identify the "
            "underlying source of iron loss, especially gastrointestinal in "
            "adults."
        ),
    },
]

# ---- Embedder -------------------------------------------------------------
# Sentence-embedding model used both to encode queries at request time and to
# embed the fallback corpus; overridable via the EMBED_MODEL env var.
# NOTE(review): loaded eagerly at import time — first startup downloads the
# model weights, so cold starts are slow.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
embedder = SentenceTransformer(EMBED_MODEL)


def load_corpus():
    """Return (records, embeddings, source_label).

    Prefers the prebuilt index at data/corpus.jsonl + data/embeddings.npy
    (label "medlineplus"); otherwise embeds the built-in sample corpus on the
    fly (label "sample"). Embeddings are float32 and L2-normalized so cosine
    similarity reduces to a dot product in retrieve().
    """
    have_prebuilt = CORPUS_PATH.exists() and EMBED_PATH.exists()
    if not have_prebuilt:
        # Fallback: embed the tiny sample corpus so the demo always works.
        sample_texts = [doc["text"] for doc in SAMPLE_CORPUS]
        sample_embs = embedder.encode(sample_texts, normalize_embeddings=True)
        return list(SAMPLE_CORPUS), sample_embs.astype("float32"), "sample"

    raw_lines = CORPUS_PATH.read_text(encoding="utf-8").splitlines()
    records = [json.loads(ln) for ln in raw_lines if ln.strip()]
    embs = np.load(EMBED_PATH).astype("float32")
    if len(records) != embs.shape[0]:
        raise RuntimeError(
            f"corpus/embedding length mismatch: {len(records)} vs {embs.shape[0]}"
        )
    # Re-normalize defensively; guard against zero rows to avoid div-by-zero.
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return records, embs / norms, "medlineplus"


# Load the retrieval index once at import time; all requests read it
# without mutation.
CORPUS_RECORDS, CORPUS_EMBS, CORPUS_SOURCE = load_corpus()
print(f"[startup] loaded {len(CORPUS_RECORDS)} chunks from {CORPUS_SOURCE} corpus")


def retrieve(query: str, k: int = 4):
    """Return the top-*k* ``(record, cosine_similarity)`` pairs for *query*.

    Both the corpus embeddings and the query vector are unit-normalized, so
    the matrix-vector product below is cosine similarity.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True)[0]
    scores = CORPUS_EMBS @ query_vec.astype("float32")
    best = np.argsort(-scores)[:k]
    return [(CORPUS_RECORDS[i], float(scores[i])) for i in best]


# ---- Hosted generation via HF Inference Providers ------------------------
# Chat model served through the Hugging Face Inference API; overridable via
# the GEN_MODEL env var. HF_TOKEN may be unset, in which case the client
# falls back to unauthenticated (rate-limited) access.
GEN_MODEL = os.environ.get("GEN_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)

SYSTEM_PROMPT = (
    "You are a clinical decision support assistant for educational use only. "
    "Answer ONLY using the provided context. If the context is insufficient, "
    "say so explicitly. Always remind the user this is not medical advice."
)


def build_messages(message, history, context):
    """Assemble the chat-completion message list for one turn.

    Args:
        message: The user's current question (converted to str).
        history: Prior chat turns; may be None. Accepts both Gradio history
            formats: "messages" style ({"role": ..., "content": ...} dicts)
            and the legacy (user_text, assistant_text) pair style. Entries
            matching neither shape are skipped.
        context: Retrieved document text injected ahead of the question.

    Returns:
        A list of {"role", "content"} dicts beginning with the system prompt
        and ending with the context-plus-question user message.
    """
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            # Forward only the keys the Inference API expects; Gradio may
            # attach extra metadata to message dicts.
            if "role" in turn and "content" in turn:
                msgs.append({"role": turn["role"], "content": turn["content"]})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            # Legacy Gradio history format: (user_text, assistant_text) pairs.
            user_text, bot_text = turn
            if user_text:
                msgs.append({"role": "user", "content": str(user_text)})
            if bot_text:
                msgs.append({"role": "assistant", "content": str(bot_text)})
    # The original built this string with chr(10) calls; a plain f-string with
    # real newlines is equivalent and far clearer.
    msgs.append(
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {message}"}
    )
    return msgs


def respond(message, history):
    """Answer one chat turn via RAG: retrieve, generate, append sources.

    Args:
        message: The user's question.
        history: Prior chat turns as passed by gr.ChatInterface.

    Returns:
        The model answer followed by a "---" separator, a deduplicated
        "Sources:" line, the disclaimer, and (for the MedlinePlus corpus)
        the required attribution line. On a generation-backend failure the
        error and the retrieved context are returned instead of crashing.
    """
    hits = retrieve(message, k=4)

    # Build the context block handed to the generator: one labelled chunk per
    # retrieved document, blank-line separated.
    pieces = []
    for doc, _score in hits:
        label = doc.get("section") or doc.get("topic") or "?"
        url = doc.get("url") or ""
        head = f"[{label} {url}]" if url else f"[{label}]"
        pieces.append(f"{head} {doc.get('text', '')}")
    context = "\n\n".join(pieces)

    messages = build_messages(message, history, context)
    try:
        out = client.chat_completion(
            messages=messages,
            max_tokens=400,
            temperature=0.2,
        )
        # content can legally be None; guard so the footer concatenation
        # below never raises TypeError.
        answer = out.choices[0].message.content or ""
    except Exception as e:
        # Deliberate best-effort: surface the backend failure in the UI
        # together with the retrieved context rather than erroring the app.
        answer = f"(Generation backend error: {e})\n\nRetrieved context:\n{context}"

    # Deduplicate source labels while preserving retrieval order.
    labels = (doc.get("topic") or doc.get("section") or "" for doc, _score in hits)
    seen = list(dict.fromkeys(label for label in labels if label))
    src_line = "Sources: " + (", ".join(seen) if seen else "(none)")
    footer = f"{src_line}\n{DISCLAIMER}"
    if CORPUS_SOURCE == "medlineplus":
        footer = f"{ATTRIBUTION_MEDLINE}\n{footer}"
    return f"{answer}\n\n---\n{footer}"


# Gradio chat UI wired to the RAG pipeline above; respond() handles each turn.
demo = gr.ChatInterface(
    fn=respond,
    title="FetchMerck AI — Demo",
    description=(
        "Lightweight RAG demo for clinical decision support. "
        "Uses a public MedlinePlus-derived corpus when present, otherwise a small sample corpus. "
        + DISCLAIMER
    ),
    examples=[
        "What is first-line management of community-acquired pneumonia?",
        "How is iron deficiency anemia evaluated in adults?",
        "Initial steps for an acute asthma exacerbation?",
    ],
)

if __name__ == "__main__":
    # Start the web server only when run as a script (not when imported).
    demo.launch()