Aswin92 commited on
Commit
89c6379
·
verified ·
1 Parent(s): b4fa60b

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .claude/settings.local.json +26 -0
  2. .dockerignore +27 -0
  3. .env +16 -0
  4. .gitattributes +4 -0
  5. Dockerfile +81 -0
  6. README.md +44 -5
  7. backend/Dockerfile +21 -0
  8. backend/api/__init__.py +1 -0
  9. backend/api/__pycache__/__init__.cpython-310.pyc +0 -0
  10. backend/api/__pycache__/hallucination_guard.cpython-310.pyc +0 -0
  11. backend/api/__pycache__/main.cpython-310.pyc +0 -0
  12. backend/api/__pycache__/models.cpython-310.pyc +0 -0
  13. backend/api/__pycache__/pipeline.cpython-310.pyc +0 -0
  14. backend/api/hallucination_guard.py +224 -0
  15. backend/api/main.py +141 -0
  16. backend/api/models.py +57 -0
  17. backend/api/pipeline.py +232 -0
  18. backend/dashboard/__init__.py +1 -0
  19. backend/dashboard/__pycache__/charts.cpython-310.pyc +0 -0
  20. backend/dashboard/app.py +472 -0
  21. backend/dashboard/charts.py +269 -0
  22. backend/reports/week4_evaluation.md +131 -0
  23. backend/requirements.txt +29 -0
  24. backend/scripts/__pycache__/graph_store.cpython-310.pyc +0 -0
  25. backend/scripts/__pycache__/symptom_parser.cpython-310.pyc +0 -0
  26. backend/scripts/download_hpo.py +75 -0
  27. backend/scripts/download_orphanet.py +232 -0
  28. backend/scripts/embed_chromadb.py +208 -0
  29. backend/scripts/graph_store.py +300 -0
  30. backend/scripts/hello_world.py +257 -0
  31. backend/scripts/ingest_hpo.py +198 -0
  32. backend/scripts/ingest_neo4j.py +192 -0
  33. backend/scripts/milestone_2a.py +344 -0
  34. backend/scripts/milestone_2b.py +185 -0
  35. backend/scripts/reembed_chromadb.py +224 -0
  36. backend/scripts/symptom_parser.py +245 -0
  37. backend/scripts/test_week3p1.py +161 -0
  38. backend/scripts/week4_evaluation.py +612 -0
  39. data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/data_level0.bin +3 -0
  40. data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/header.bin +3 -0
  41. data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/index_metadata.pickle +3 -0
  42. data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/length.bin +3 -0
  43. data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/link_lists.bin +3 -0
  44. data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/data_level0.bin +3 -0
  45. data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/header.bin +3 -0
  46. data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/index_metadata.pickle +3 -0
  47. data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/length.bin +3 -0
  48. data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/link_lists.bin +3 -0
  49. data/chromadb/chroma.sqlite3 +3 -0
  50. data/graph_store.json +3 -0
.claude/settings.local.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(pip install:*)",
5
+ "Bash(docker compose:*)",
6
+ "Bash(where docker:*)",
7
+ "Read(//c/Program Files/Docker/Docker/resources/**)",
8
+ "Read(//c/Program Files/**)",
9
+ "Read(//c/Users/Aswin/AppData/Local/**)",
10
+ "Bash(cmd.exe /c \"where docker 2>&1\")",
11
+ "Bash(cmd.exe /c \"docker --version 2>&1 && docker compose version 2>&1\")",
12
+ "Bash(python:*)",
13
+ "Bash(curl -s http://localhost:8080/health)",
14
+ "Bash(curl -s -X POST http://localhost:8080/diagnose -H \"Content-Type: application/json\" -d \"{\"\"note\"\": \"\"18 year old male, extremely tall, displaced lens in left eye, heart murmur, flexible joints, scoliosis\"\", \"\"top_n\"\": 10}\")",
15
+ "Bash(kill 1981)",
16
+ "Bash(wait)",
17
+ "Bash(curl -s -X POST http://localhost:8080/diagnose -H \"Content-Type: application/json\" -d \"{\"\"note\"\": \"\"18 year old male, extremely tall, displaced lens in left eye, heart murmur, flexible joints, scoliosis\"\", \"\"top_n\"\": 15, \"\"threshold\"\": 0.52}\")",
18
+ "Bash(kill 2076 2085)",
19
+ "Bash(python -c \":*)",
20
+ "WebFetch(domain:api.github.com)",
21
+ "WebFetch(domain:raw.githubusercontent.com)",
22
+ "WebSearch",
23
+ "WebFetch(domain:github.com)"
24
+ ]
25
+ }
26
+ }
.dockerignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Version control
2
+ .git
3
+ .gitignore
4
+
5
+ # Python cache
6
+ **/__pycache__
7
+ **/*.pyc
8
+ **/*.pyo
9
+ **/*.pyd
10
+ **/.pytest_cache
11
+
12
+ # Orphanet raw XML — large, only needed to regenerate data, not at runtime
13
+ data/orphanet/
14
+
15
+ # Reports — not needed in container
16
+ backend/reports/
17
+
18
+ # Local dev files
19
+ .env
20
+ docker-compose.yml
21
+ backend/Dockerfile
22
+
23
+ # Editor / OS
24
+ .vscode
25
+ .idea
26
+ *.DS_Store
27
+ Thumbs.db
.env ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neo4j
2
+ NEO4J_URI=bolt://localhost:7687
3
+ NEO4J_USER=neo4j
4
+ NEO4J_PASSWORD=raredx_password
5
+
6
+ # ChromaDB
7
+ CHROMA_HOST=localhost
8
+ CHROMA_PORT=8000
9
+
10
+ # Data paths
11
+ ORPHANET_DATA_DIR=./data/orphanet
12
+ ORPHANET_XML=./data/orphanet/en_product1.xml
13
+
14
+ # BioLORD model
15
+ EMBED_MODEL=FremyCompany/BioLORD-2023
16
+ CHROMA_COLLECTION=rare_diseases
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/chromadb/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
37
+ data/graph_store.json filter=lfs diff=lfs merge=lfs -text
38
+ data/orphanet/en_product1.xml filter=lfs diff=lfs merge=lfs -text
39
+ data/orphanet/en_product4.xml filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # RareDx — Hugging Face Spaces Dockerfile
3
+ # Single container: FastAPI (8080, internal) + Streamlit (8501, public)
4
+ # =============================================================================
5
+
6
+ FROM python:3.11-slim
7
+
8
+ # --------------------------------------------------------------------------
9
+ # System dependencies
10
+ # --------------------------------------------------------------------------
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ gcc \
13
+ g++ \
14
+ libxml2-dev \
15
+ libxslt-dev \
16
+ curl \
17
+ supervisor \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ WORKDIR /app
21
+
22
+ # --------------------------------------------------------------------------
23
+ # Python dependencies
24
+ # Install before copying source so this layer is cached on code-only changes
25
+ # --------------------------------------------------------------------------
26
+ COPY backend/requirements.txt ./requirements.txt
27
+ RUN pip install --no-cache-dir -r requirements.txt
28
+
29
+ # --------------------------------------------------------------------------
30
+ # Pre-download BioLORD-2023 model into the image
31
+ # This avoids a ~500MB download on every Space restart
32
+ # --------------------------------------------------------------------------
33
+ ENV HF_HOME=/app/.cache/huggingface
34
+ RUN python -c "\
35
+ from sentence_transformers import SentenceTransformer; \
36
+ print('Downloading BioLORD-2023...'); \
37
+ SentenceTransformer('FremyCompany/BioLORD-2023'); \
38
+ print('Model cached.')"
39
+
40
+ # --------------------------------------------------------------------------
41
+ # Application source
42
+ # --------------------------------------------------------------------------
43
+ COPY backend/ ./backend/
44
+
45
+ # --------------------------------------------------------------------------
46
+ # Pre-built knowledge data (bundled — no runtime download needed)
47
+ # data/graph_store.json — 33MB Orphanet+HPO knowledge graph (NetworkX JSON)
48
+ # data/chromadb/ — 149MB BioLORD disease embeddings (ChromaDB)
49
+ # data/hpo_index/ — 26MB BioLORD HPO term embeddings (numpy + JSON)
50
+ # --------------------------------------------------------------------------
51
+ COPY data/graph_store.json ./data/graph_store.json
52
+ COPY data/chromadb/ ./data/chromadb/
53
+ COPY data/hpo_index/ ./data/hpo_index/
54
+
55
+ # --------------------------------------------------------------------------
56
+ # supervisord config
57
+ # --------------------------------------------------------------------------
58
+ COPY supervisord.conf /etc/supervisor/conf.d/raredx.conf
59
+
60
+ # --------------------------------------------------------------------------
61
+ # Runtime environment
62
+ # Tell pipeline to use embedded ChromaDB and local graph store
63
+ # (no Neo4j or external ChromaDB server needed)
64
+ # --------------------------------------------------------------------------
65
+ ENV PYTHONUNBUFFERED=1 \
66
+ PYTHONDONTWRITEBYTECODE=1 \
67
+ CHROMA_HOST=localhost \
68
+ CHROMA_PORT=9999 \
69
+ CHROMA_COLLECTION=rare_diseases \
70
+ EMBED_MODEL=FremyCompany/BioLORD-2023 \
71
+ ORPHANET_DATA_DIR=/app/data/orphanet
72
+
73
+ # Port Streamlit listens on (declared for HF Spaces)
74
+ EXPOSE 8501
75
+
76
+ # --------------------------------------------------------------------------
77
+ # Start both services via supervisord
78
+ # FastAPI: 127.0.0.1:8080 (internal — Streamlit calls it)
79
+ # Streamlit: 0.0.0.0:8501 (public — HF Spaces exposes this)
80
+ # --------------------------------------------------------------------------
81
+ CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/raredx.conf"]
README.md CHANGED
@@ -1,10 +1,49 @@
1
  ---
2
- title: Raredx
3
- emoji: 📉
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: docker
 
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: RareDx — Rare Disease Diagnostic AI
3
+ emoji: 🧬
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 8501
8
  pinned: false
9
+ short_description: Multi-agent AI for rare disease diagnosis
10
  ---
11
 
12
+ # RareDx Rare Disease Diagnostic AI
13
+
14
+ A multi-agent clinical AI system that generates differential diagnoses for rare diseases from plain-text clinical notes.
15
+
16
+ ## How It Works
17
+
18
+ 1. **Symptom Parser** — maps free-text symptoms to HPO term IDs using BioLORD-2023 semantic similarity
19
+ 2. **Graph Search** — traverses the Orphanet/HPO knowledge graph (11,456 diseases, 115,839 phenotype associations)
20
+ 3. **Vector Search** — BioLORD semantic search over HPO-enriched disease embeddings
21
+ 4. **RRF Fusion** — merges both rankings via Reciprocal Rank Fusion
22
+ 5. **Hallucination Guard** — FusionNode filters candidates lacking phenotype evidence
23
+
24
+ ## Example
25
+
26
+ Paste a clinical note like:
27
+ > *"18 year old male, extremely tall, displaced lens in left eye, heart murmur, flexible joints, scoliosis"*
28
+
29
+ The system returns ranked differential diagnoses with evidence scores, matched HPO terms, and an interactive evidence map.
30
+
31
+ ## Architecture
32
+
33
+ | Component | Technology |
34
+ |-----------|-----------|
35
+ | Knowledge graph | Orphanet + HPO (NetworkX JSON) |
36
+ | Embeddings | FremyCompany/BioLORD-2023 (768-dim) |
37
+ | Vector store | ChromaDB (embedded) |
38
+ | API | FastAPI on port 8080 (internal) |
39
+ | Dashboard | Streamlit on port 8501 |
40
+
41
+ ## Data Sources
42
+
43
+ - [Orphanet](https://www.orphadata.com/) — rare disease names, definitions, HPO phenotype associations
44
+ - [Human Phenotype Ontology](https://hpo.jax.org/) — 8,701 standardised phenotype terms
45
+ - [BioLORD-2023](https://huggingface.co/FremyCompany/BioLORD-2023) — biomedical sentence encoder
46
+
47
+ ## Startup Note
48
+
49
+ The pipeline loads the BioLORD model and 11,456-disease knowledge graph on first request. Allow ~30 seconds after the Space starts before submitting a note.
backend/Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # System deps for lxml, torch, etc.
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ gcc \
8
+ libxml2-dev \
9
+ libxslt-dev \
10
+ curl \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ COPY requirements.txt .
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ COPY scripts/ ./scripts/
17
+ COPY ../.env .env 2>/dev/null || true
18
+
19
+ ENV PYTHONUNBUFFERED=1
20
+
21
+ CMD ["python", "scripts/hello_world.py"]
backend/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # RareDx API package
backend/api/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file
 
backend/api/__pycache__/hallucination_guard.cpython-310.pyc ADDED
Binary file (7.05 kB). View file
 
backend/api/__pycache__/main.cpython-310.pyc ADDED
Binary file (4.25 kB). View file
 
backend/api/__pycache__/models.cpython-310.pyc ADDED
Binary file (2.28 kB). View file
 
backend/api/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (6.55 kB). View file
 
backend/api/hallucination_guard.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ hallucination_guard.py — FusionNode
3
+ ------------------------------------
4
+ Post-RRF evidence validation for diagnostic candidates.
5
+
6
+ A "hallucination" in this retrieval pipeline means a disease appears in the
7
+ ranked list without adequate grounding in *either* the graph or the vector
8
+ store — it floated up via RRF scoring alone, not because the evidence is
9
+ genuinely strong.
10
+
11
+ Three rules evaluated per candidate:
12
+
13
+ Rule 1 — Graph-only candidates (no ChromaDB match):
14
+ Must have match_count >= min_graph_matches.
15
+ A disease sharing only 1 HPO term with the query (out of 4-6) is
16
+ coincidental overlap; every disease eventually shares *some* HPO term
17
+ with any symptom list.
18
+
19
+ Rule 2 — Vector-only candidates (no graph match):
20
+ Must have chroma_sim >= min_vector_sim.
21
+ Moderate BioLORD similarity (0.60-0.64) can come from superficial
22
+ name/description overlap. High confidence (>=0.65) indicates the
23
+ disease's full phenotype description genuinely matches the query.
24
+
25
+ Rule 3 — Frequency validation (graph candidates with matches):
26
+ At least one matched HPO term must be "frequent" or "very frequent"
27
+ in that disease (frequency_order <= 2, i.e. >=30% prevalence).
28
+ A disease where all matched symptoms are "rare" (4-1%) is an unlikely
29
+ diagnosis even if the HPO terms technically overlap.
30
+
31
+ Overlap bonus — candidates appearing in BOTH graph and vector rankings
32
+ always pass rules 1 and 2. Two independent retrieval systems agreeing
33
+ is strong evidence; we only apply Rule 3.
34
+
35
+ Flagged candidates are returned with `hallucination_flag=True` and a
36
+ `flag_reason` string explaining which rule failed. They are NOT removed
37
+ from the response — the caller decides whether to surface or hide them.
38
+ """
39
+
40
+ from dataclasses import dataclass, field
41
+
42
+
43
+ # Frequency order → label mapping (from ingest_hpo.py)
44
+ # 1 = Very frequent (>=80%), 2 = Frequent (30-79%)
45
+ FREQUENT_THRESHOLD = 2
46
+
47
+
48
+ @dataclass
49
+ class GuardResult:
50
+ candidate: dict
51
+ passed: bool
52
+ flag_reason: str | None # None when passed=True
53
+ evidence_score: float # 0.0–1.0 composite evidence strength
54
+
55
+
56
+ class FusionNode:
57
+ """
58
+ Evaluates each RRF candidate and flags those without sufficient evidence.
59
+
60
+ Parameters
61
+ ----------
62
+ min_graph_matches : int
63
+ Minimum HPO term matches required for graph-only candidates. Default 2.
64
+ min_vector_sim : float
65
+ Minimum cosine similarity required for vector-only candidates. Default 0.65.
66
+ require_frequent_match : bool
67
+ If True, graph candidates must have ≥1 HPO match that is "frequent"
68
+ or "very frequent" in the disease. Default True.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ min_graph_matches: int = 2,
74
+ min_vector_sim: float = 0.65,
75
+ require_frequent_match: bool = True,
76
+ ) -> None:
77
+ self.min_graph_matches = min_graph_matches
78
+ self.min_vector_sim = min_vector_sim
79
+ self.require_frequent_match = require_frequent_match
80
+
81
+ # ------------------------------------------------------------------
82
+ # Public API
83
+ # ------------------------------------------------------------------
84
+
85
+ def evaluate(self, candidate: dict, total_query_terms: int) -> GuardResult:
86
+ """Evaluate a single candidate. Returns a GuardResult."""
87
+ has_graph = candidate.get("graph_rank") is not None
88
+ has_vector = candidate.get("chroma_rank") is not None
89
+ matches = candidate.get("graph_matches") or 0
90
+ sim = candidate.get("chroma_sim") or 0.0
91
+ matched_hpo = candidate.get("matched_hpo", [])
92
+
93
+ evidence_score = self._compute_evidence(
94
+ has_graph, has_vector, matches, total_query_terms, sim, matched_hpo
95
+ )
96
+
97
+ # Overlap candidates — always pass rules 1 & 2
98
+ if has_graph and has_vector:
99
+ if self.require_frequent_match and matched_hpo:
100
+ ok, reason = self._check_frequency(matched_hpo)
101
+ if not ok:
102
+ return GuardResult(candidate, False, reason, evidence_score)
103
+ return GuardResult(candidate, True, None, evidence_score)
104
+
105
+ # Graph-only
106
+ if has_graph and not has_vector:
107
+ if matches < self.min_graph_matches:
108
+ reason = (
109
+ f"Rule 1: graph-only with only {matches}/{total_query_terms} "
110
+ f"HPO matches (minimum {self.min_graph_matches} required)"
111
+ )
112
+ return GuardResult(candidate, False, reason, evidence_score)
113
+
114
+ if self.require_frequent_match and matched_hpo:
115
+ ok, reason = self._check_frequency(matched_hpo)
116
+ if not ok:
117
+ return GuardResult(candidate, False, reason, evidence_score)
118
+
119
+ return GuardResult(candidate, True, None, evidence_score)
120
+
121
+ # Vector-only
122
+ if has_vector and not has_graph:
123
+ if sim < self.min_vector_sim:
124
+ reason = (
125
+ f"Rule 2: vector-only with similarity {sim:.4f} "
126
+ f"(minimum {self.min_vector_sim} required)"
127
+ )
128
+ return GuardResult(candidate, False, reason, evidence_score)
129
+ return GuardResult(candidate, True, None, evidence_score)
130
+
131
+ # Neither — shouldn't happen after RRF, but guard defensively
132
+ return GuardResult(
133
+ candidate, False,
134
+ "Rule 0: candidate has neither graph nor vector evidence",
135
+ 0.0,
136
+ )
137
+
138
+ def filter(
139
+ self,
140
+ candidates: list[dict],
141
+ total_query_terms: int,
142
+ ) -> tuple[list[dict], list[dict]]:
143
+ """
144
+ Evaluate all candidates and split into (passed, flagged) lists.
145
+ Both lists preserve original RRF rank order.
146
+ The `candidate` dicts are mutated in-place to add:
147
+ - hallucination_flag: bool
148
+ - flag_reason: str | None
149
+ - evidence_score: float
150
+ """
151
+ passed, flagged = [], []
152
+
153
+ for c in candidates:
154
+ result = self.evaluate(c, total_query_terms)
155
+ c["hallucination_flag"] = not result.passed
156
+ c["flag_reason"] = result.flag_reason
157
+ c["evidence_score"] = round(result.evidence_score, 4)
158
+
159
+ if result.passed:
160
+ passed.append(c)
161
+ else:
162
+ flagged.append(c)
163
+
164
+ return passed, flagged
165
+
166
+ # ------------------------------------------------------------------
167
+ # Private helpers
168
+ # ------------------------------------------------------------------
169
+
170
+ def _check_frequency(
171
+ self, matched_hpo: list[dict]
172
+ ) -> tuple[bool, str | None]:
173
+ """
174
+ Rule 3: at least one matched HPO term must be frequent/very frequent.
175
+ matched_hpo items carry a 'frequency_label' string from graph edges.
176
+ """
177
+ for h in matched_hpo:
178
+ label = h.get("frequency_label", "").lower()
179
+ # Accept "very frequent", "frequent", "obligate"
180
+ if any(kw in label for kw in ("very frequent", "frequent", "obligate")):
181
+ return True, None
182
+
183
+ terms_str = ", ".join(h.get("term", "") for h in matched_hpo)
184
+ return False, (
185
+ f"Rule 3: all matched HPO terms are rare/occasional in this disease "
186
+ f"({terms_str}) — unlikely diagnosis"
187
+ )
188
+
189
+ @staticmethod
190
+ def _compute_evidence(
191
+ has_graph: bool,
192
+ has_vector: bool,
193
+ matches: int,
194
+ total: int,
195
+ sim: float,
196
+ matched_hpo: list[dict],
197
+ ) -> float:
198
+ """
199
+ Composite evidence score 0–1.
200
+ Overlap (both signals) gets a 1.25× bonus capped at 1.0.
201
+ """
202
+ graph_score = 0.0
203
+ if has_graph and total > 0:
204
+ # Base: fraction of query terms matched
205
+ overlap_ratio = matches / total
206
+ # Frequency bonus: proportion of matched terms that are frequent
207
+ freq_count = sum(
208
+ 1 for h in matched_hpo
209
+ if any(kw in h.get("frequency_label", "").lower()
210
+ for kw in ("very frequent", "frequent", "obligate"))
211
+ )
212
+ freq_ratio = freq_count / matches if matches else 0.0
213
+ graph_score = 0.6 * overlap_ratio + 0.4 * freq_ratio
214
+
215
+ vector_score = sim if has_vector else 0.0
216
+
217
+ if has_graph and has_vector:
218
+ raw = (graph_score + vector_score) / 2 * 1.25
219
+ elif has_graph:
220
+ raw = graph_score
221
+ else:
222
+ raw = vector_score
223
+
224
+ return min(raw, 1.0)
backend/api/main.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ main.py — RareDx FastAPI application.
3
+
4
+ Run with:
5
+ uvicorn backend.api.main:app --reload --port 8080
6
+
7
+ Or from the project root:
8
+ python -m uvicorn backend.api.main:app --reload --port 8080
9
+
10
+ Endpoints:
11
+ POST /diagnose — clinical note → differential diagnosis
12
+ GET /health — liveness check
13
+ GET /hpo/search?q=... — debug: find HPO terms by keyword
14
+ """
15
+
16
+ import sys
17
+ from contextlib import asynccontextmanager
18
+ from pathlib import Path
19
+
20
+ from fastapi import FastAPI, HTTPException
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+
23
+ # Ensure scripts/ importable
24
+ sys.path.insert(0, str(Path(__file__).parents[1] / "scripts"))
25
+
26
+ from .models import DiagnoseRequest, DiagnoseResponse, Candidate, HPOMatch
27
+ from .pipeline import DiagnosisPipeline
28
+
29
+ # Pipeline is loaded once at startup (model loading takes ~3s)
30
+ pipeline: DiagnosisPipeline | None = None
31
+
32
+
33
+ @asynccontextmanager
34
+ async def lifespan(app: FastAPI):
35
+ global pipeline
36
+ print("Starting up RareDx API...")
37
+ pipeline = DiagnosisPipeline()
38
+ print("API ready.")
39
+ yield
40
+ print("Shutting down.")
41
+
42
+
43
+ app = FastAPI(
44
+ title="RareDx API",
45
+ description=(
46
+ "Multi-agent clinical AI for rare disease diagnosis. "
47
+ "Combines knowledge graph (Orphanet/HPO) with BioLORD-2023 "
48
+ "semantic embeddings to generate differential diagnoses from "
49
+ "plain-text clinical notes."
50
+ ),
51
+ version="0.2.0",
52
+ lifespan=lifespan,
53
+ )
54
+
55
+ app.add_middleware(
56
+ CORSMiddleware,
57
+ allow_origins=["*"],
58
+ allow_methods=["*"],
59
+ allow_headers=["*"],
60
+ )
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Endpoints
65
+ # ---------------------------------------------------------------------------
66
+
67
+ @app.get("/health")
68
+ def health():
69
+ return {
70
+ "status": "ok",
71
+ "pipeline_ready": pipeline is not None,
72
+ "graph_backend": pipeline.graph_backend if pipeline else None,
73
+ "chroma_backend": pipeline.chroma_backend if pipeline else None,
74
+ }
75
+
76
+
77
+ @app.post("/diagnose", response_model=DiagnoseResponse)
78
+ def diagnose(request: DiagnoseRequest):
79
+ if pipeline is None:
80
+ raise HTTPException(status_code=503, detail="Pipeline not initialised.")
81
+
82
+ result = pipeline.diagnose(
83
+ note=request.note,
84
+ top_n=request.top_n,
85
+ threshold=request.threshold,
86
+ )
87
+
88
+ def _to_candidate(c: dict) -> Candidate:
89
+ return Candidate(
90
+ rank = c["rank"],
91
+ orpha_code = c["orpha_code"],
92
+ name = c["name"],
93
+ definition = c.get("definition") or None,
94
+ rrf_score = c["rrf_score"],
95
+ graph_rank = c.get("graph_rank"),
96
+ chroma_rank = c.get("chroma_rank"),
97
+ graph_matches = c.get("graph_matches"),
98
+ chroma_sim = c.get("chroma_sim"),
99
+ matched_hpo = c.get("matched_hpo", []),
100
+ hallucination_flag = c.get("hallucination_flag", False),
101
+ flag_reason = c.get("flag_reason"),
102
+ evidence_score = c.get("evidence_score", 0.0),
103
+ )
104
+
105
+ candidates = [_to_candidate(c) for c in result["candidates"]]
106
+ passed_candidates = [_to_candidate(c) for c in result["passed_candidates"]]
107
+ flagged_candidates= [_to_candidate(c) for c in result["flagged_candidates"]]
108
+ hpo_matches = [HPOMatch(**m) for m in result["hpo_matches"]]
109
+ top = passed_candidates[0] if passed_candidates else (candidates[0] if candidates else None)
110
+
111
+ return DiagnoseResponse(
112
+ note = result["note"],
113
+ phrases_extracted = result["phrases_extracted"],
114
+ hpo_matches = hpo_matches,
115
+ hpo_ids_used = result["hpo_ids_used"],
116
+ candidates = candidates,
117
+ passed_candidates = passed_candidates,
118
+ flagged_candidates = flagged_candidates,
119
+ top_diagnosis = top,
120
+ graph_backend = result["graph_backend"],
121
+ chroma_backend = result["chroma_backend"],
122
+ elapsed_seconds = result["elapsed_seconds"],
123
+ )
124
+
125
+
126
+ @app.get("/hpo/search")
127
+ def hpo_search(q: str, limit: int = 10):
128
+ """Debug endpoint: find HPO terms by keyword in graph store."""
129
+ if pipeline is None:
130
+ raise HTTPException(status_code=503, detail="Pipeline not initialised.")
131
+
132
+ store = pipeline.graph_store
133
+ q_lower = q.lower()
134
+ results = []
135
+ for _, attrs in store.graph.nodes(data=True):
136
+ if attrs.get("type") == "HPOTerm":
137
+ if q_lower in attrs.get("term", "").lower():
138
+ results.append({"hpo_id": attrs["hpo_id"], "term": attrs["term"]})
139
+ if len(results) >= limit:
140
+ break
141
+ return {"query": q, "results": results}
backend/api/models.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models.py — Pydantic request / response models for the RareDx diagnosis API.
3
+ """
4
+
5
+ from pydantic import BaseModel, Field
6
+ from typing import Optional
7
+
8
+
9
+ class DiagnoseRequest(BaseModel):
10
+ note: str = Field(
11
+ ...,
12
+ min_length=10,
13
+ description="Plain-text clinical note describing the patient's presentation.",
14
+ examples=["18 year old male, extremely tall, displaced lens, heart murmur, scoliosis"],
15
+ )
16
+ top_n: int = Field(default=10, ge=1, le=50, description="Max candidates to return.")
17
+ threshold: float = Field(
18
+ default=0.55, ge=0.3, le=0.95,
19
+ description="Minimum BioLORD cosine similarity to accept an HPO match.",
20
+ )
21
+
22
+
23
+ class HPOMatch(BaseModel):
24
+ phrase: str
25
+ hpo_id: str
26
+ term: str
27
+ score: float
28
+
29
+
30
+ class Candidate(BaseModel):
31
+ rank: int
32
+ orpha_code: str
33
+ name: str
34
+ definition: Optional[str]
35
+ rrf_score: float
36
+ graph_rank: Optional[int]
37
+ chroma_rank: Optional[int]
38
+ graph_matches: Optional[int] # number of HPO terms matched in graph
39
+ chroma_sim: Optional[float] # cosine similarity from ChromaDB
40
+ matched_hpo: list[dict] = [] # HPO terms matched via graph
41
+ hallucination_flag: bool = False
42
+ flag_reason: Optional[str] = None
43
+ evidence_score: float = 0.0
44
+
45
+
46
+ class DiagnoseResponse(BaseModel):
47
+ note: str
48
+ phrases_extracted: list[str]
49
+ hpo_matches: list[HPOMatch]
50
+ hpo_ids_used: list[str]
51
+ candidates: list[Candidate] # all candidates (flag fields attached)
52
+ passed_candidates: list[Candidate] # guard passed
53
+ flagged_candidates: list[Candidate] # guard flagged
54
+ top_diagnosis: Optional[Candidate]
55
+ graph_backend: str
56
+ chroma_backend: str
57
+ elapsed_seconds: float
backend/api/pipeline.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ pipeline.py
3
+ -----------
4
+ DiagnosisPipeline — the core reasoning engine for RareDx.
5
+
6
+ Shared between the FastAPI app (loaded once at startup) and the
7
+ milestone_2b.py script (instantiated directly).
8
+
9
+ Steps:
10
+ 1. SymptomParser → map clinical note phrases to HPO IDs (BioLORD semantic)
11
+ 2. GraphSearch → MANIFESTS_AS traversal ranked by phenotype overlap
12
+ 3. ChromaSearch → BioLORD semantic search over HPO-enriched embeddings
13
+ 4. RRF Fusion → merge both rankings via Reciprocal Rank Fusion
14
+ 5. FusionNode → hallucination guard: flag candidates lacking evidence
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import time
20
+ import concurrent.futures
21
+ from pathlib import Path
22
+
23
+ import chromadb
24
+ from chromadb.config import Settings
25
+ from sentence_transformers import SentenceTransformer
26
+ from dotenv import load_dotenv
27
+
28
+ load_dotenv(Path(__file__).parents[2] / ".env")
29
+
30
+ # Ensure scripts/ and api/ are importable
31
+ SCRIPTS_DIR = Path(__file__).parents[1] / "scripts"
32
+ API_DIR = Path(__file__).parent
33
+ sys.path.insert(0, str(SCRIPTS_DIR))
34
+ sys.path.insert(0, str(API_DIR))
35
+
36
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
37
+ CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
38
+ COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
39
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
40
+ CHROMA_PERSIST = Path(__file__).parents[2] / "data" / "chromadb"
41
+
42
+ RRF_K = 60 # Standard constant for Reciprocal Rank Fusion
43
+
44
+
45
+ class DiagnosisPipeline:
46
+ """
47
+ Initialise once per process; call .diagnose(note) for each request.
48
+ Thread-safe: graph traversal and ChromaDB query run in parallel threads.
49
+ """
50
+
51
+ def __init__(self) -> None:
52
+ print("Initialising DiagnosisPipeline...")
53
+
54
+ # BioLORD model (shared by symptom parser + ChromaDB query)
55
+ print(" Loading BioLORD-2023...")
56
+ self.model = SentenceTransformer(EMBED_MODEL)
57
+
58
+ # SymptomParser (also builds / loads HPO index)
59
+ from symptom_parser import SymptomParser
60
+ self.symptom_parser = SymptomParser(self.model)
61
+
62
+ # ChromaDB client
63
+ self.chroma_col, self.chroma_backend = self._init_chroma()
64
+
65
+ # Graph store
66
+ from graph_store import LocalGraphStore
67
+ self.graph_store = LocalGraphStore()
68
+ self.graph_backend = "LocalGraphStore (JSON)"
69
+
70
+ # Hallucination guard
71
+ from hallucination_guard import FusionNode
72
+ self.fusion_node = FusionNode(
73
+ min_graph_matches=2,
74
+ min_vector_sim=0.65,
75
+ require_frequent_match=True,
76
+ )
77
+
78
+ print("Pipeline ready.")
79
+
80
+ # ------------------------------------------------------------------
81
+ # Initialisation helpers
82
+ # ------------------------------------------------------------------
83
+
84
+ def _init_chroma(self):
85
+ try:
86
+ client = chromadb.HttpClient(
87
+ host=CHROMA_HOST, port=CHROMA_PORT,
88
+ settings=Settings(anonymized_telemetry=False),
89
+ )
90
+ client.heartbeat()
91
+ col = client.get_collection(COLLECTION_NAME)
92
+ return col, "ChromaDB HTTP (Docker)"
93
+ except Exception:
94
+ client = chromadb.PersistentClient(
95
+ path=str(CHROMA_PERSIST),
96
+ settings=Settings(anonymized_telemetry=False),
97
+ )
98
+ col = client.get_collection(COLLECTION_NAME)
99
+ return col, "ChromaDB Embedded"
100
+
101
+ # ------------------------------------------------------------------
102
+ # Core diagnosis
103
+ # ------------------------------------------------------------------
104
+
105
+ def diagnose(
106
+ self,
107
+ note: str,
108
+ top_n: int = 10,
109
+ threshold: float = 0.55,
110
+ ) -> dict:
111
+ t_start = time.time()
112
+
113
+ # Step 1: symptom parsing
114
+ self.symptom_parser.threshold = threshold
115
+ hpo_matches = self.symptom_parser.parse(note)
116
+ hpo_ids = [m.hpo_id for m in hpo_matches]
117
+ phrases = [m.phrase for m in hpo_matches]
118
+
119
+ # Steps 2 & 3: parallel graph + vector search
120
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
121
+ graph_fut = pool.submit(self._graph_search, hpo_ids, top_n)
122
+ chroma_fut = pool.submit(self._chroma_search, note, top_n)
123
+ graph_hits = graph_fut.result()
124
+ chroma_hits = chroma_fut.result()
125
+
126
+ # Step 4: RRF fusion
127
+ fused = self._rrf_fuse(graph_hits, chroma_hits)[:top_n]
128
+
129
+ # Step 5: Hallucination guard
130
+ passed, flagged = self.fusion_node.filter(fused, total_query_terms=len(hpo_ids))
131
+
132
+ # Top diagnosis is the highest-ranked *passed* candidate;
133
+ # fall back to highest-ranked overall if everything is flagged.
134
+ top = passed[0] if passed else (fused[0] if fused else None)
135
+
136
+ return {
137
+ "note": note,
138
+ "phrases_extracted": phrases,
139
+ "hpo_matches": [
140
+ {"phrase": m.phrase, "hpo_id": m.hpo_id,
141
+ "term": m.term, "score": m.score}
142
+ for m in hpo_matches
143
+ ],
144
+ "hpo_ids_used": hpo_ids,
145
+ "candidates": fused, # all candidates, flag fields attached
146
+ "passed_candidates": passed,
147
+ "flagged_candidates": flagged,
148
+ "top_diagnosis": top,
149
+ "graph_backend": self.graph_backend,
150
+ "chroma_backend": self.chroma_backend,
151
+ "elapsed_seconds": round(time.time() - t_start, 3),
152
+ }
153
+
154
+ # ------------------------------------------------------------------
155
+ # Graph traversal
156
+ # ------------------------------------------------------------------
157
+
158
+ def _graph_search(self, hpo_ids: list[str], top_n: int) -> list[dict]:
159
+ if not hpo_ids:
160
+ return []
161
+ return self.graph_store.find_diseases_by_hpo(hpo_ids, top_n=top_n)
162
+
163
+ # ------------------------------------------------------------------
164
+ # ChromaDB semantic search
165
+ # ------------------------------------------------------------------
166
+
167
+ def _chroma_search(self, note: str, top_n: int) -> list[dict]:
168
+ emb = self.model.encode([note], normalize_embeddings=True)
169
+ results = self.chroma_col.query(
170
+ query_embeddings=emb.tolist(),
171
+ n_results=top_n,
172
+ include=["metadatas", "distances"],
173
+ )
174
+ hits = []
175
+ for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
176
+ hits.append({
177
+ "orpha_code": meta.get("orpha_code"),
178
+ "name": meta.get("name"),
179
+ "definition": meta.get("definition", ""),
180
+ "cosine_similarity": round(1 - dist, 4),
181
+ })
182
+ return hits
183
+
184
+ # ------------------------------------------------------------------
185
+ # RRF fusion
186
+ # ------------------------------------------------------------------
187
+
188
+ def _rrf_fuse(
189
+ self,
190
+ graph_results: list[dict],
191
+ chroma_results: list[dict],
192
+ ) -> list[dict]:
193
+ scores: dict[str, dict] = {}
194
+
195
+ for rank, d in enumerate(graph_results, 1):
196
+ key = str(d["orpha_code"])
197
+ if key not in scores:
198
+ scores[key] = self._base_entry(d)
199
+ scores[key]["rrf_score"] += 1 / (RRF_K + rank)
200
+ scores[key]["graph_rank"] = rank
201
+ scores[key]["graph_matches"] = d.get("match_count", 0)
202
+ scores[key]["matched_hpo"] = d.get("matched_hpo", [])
203
+
204
+ for rank, d in enumerate(chroma_results, 1):
205
+ key = str(d["orpha_code"])
206
+ if key not in scores:
207
+ scores[key] = self._base_entry(d)
208
+ scores[key]["rrf_score"] += 1 / (RRF_K + rank)
209
+ scores[key]["chroma_rank"] = rank
210
+ scores[key]["chroma_sim"] = d.get("cosine_similarity")
211
+
212
+ ranked = sorted(scores.values(), key=lambda x: x["rrf_score"], reverse=True)
213
+
214
+ for i, entry in enumerate(ranked, 1):
215
+ entry["rank"] = i
216
+ entry["rrf_score"] = round(entry["rrf_score"], 5)
217
+ return ranked
218
+
219
+ @staticmethod
220
+ def _base_entry(d: dict) -> dict:
221
+ return {
222
+ "rank": 0,
223
+ "orpha_code": str(d["orpha_code"]),
224
+ "name": d.get("name", ""),
225
+ "definition": d.get("definition", ""),
226
+ "rrf_score": 0.0,
227
+ "graph_rank": None,
228
+ "chroma_rank": None,
229
+ "graph_matches": None,
230
+ "chroma_sim": None,
231
+ "matched_hpo": [],
232
+ }
backend/dashboard/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # RareDx dashboard package
backend/dashboard/__pycache__/charts.cpython-310.pyc ADDED
Binary file (7.83 kB). View file
 
backend/dashboard/app.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py — RareDx Streamlit Dashboard
3
+ -------------------------------------
4
+ Run with:
5
+ streamlit run backend/dashboard/app.py
6
+
7
+ Requires the FastAPI server running on localhost:8080:
8
+ uvicorn backend.api.main:app --port 8080
9
+
10
+ Falls back to direct pipeline import if the API is unreachable.
11
+
12
+ Tabs:
13
+ 1. Results — ranked differential diagnosis + Why button per candidate
14
+ 2. Evidence Map — bipartite HPO-to-disease graph + guard donut + RRF breakdown
15
+ 3. Agent Audit Trail — step-by-step pipeline trace with timings
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import sys
21
+ import time
22
+ import json
23
+ from pathlib import Path
24
+
25
+ import requests
26
+ import streamlit as st
27
+
28
+ # Make backend packages importable when run from project root
29
+ ROOT = Path(__file__).parents[2]
30
+ sys.path.insert(0, str(ROOT / "backend" / "scripts"))
31
+ sys.path.insert(0, str(ROOT / "backend" / "api"))
32
+ sys.path.insert(0, str(ROOT / "backend"))
33
+
34
+ from charts import evidence_bar, evidence_map, guard_donut, rrf_waterfall
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Config
38
+ # ---------------------------------------------------------------------------
39
+
40
+ API_BASE = "http://localhost:8080"
41
+
42
+ DEFAULT_NOTE = (
43
+ "18 year old male, extremely tall, displaced lens in left eye, "
44
+ "heart murmur, flexible joints, scoliosis"
45
+ )
46
+
47
+ st.set_page_config(
48
+ page_title="RareDx — Rare Disease Diagnostic AI",
49
+ layout="wide",
50
+ initial_sidebar_state="expanded",
51
+ )
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # CSS — minimal dark-card theme
55
+ # ---------------------------------------------------------------------------
56
+
57
+ st.markdown("""
58
+ <style>
59
+ .block-container { padding-top: 1.5rem; }
60
+ .metric-card {
61
+ background: #1e293b; border-radius: 10px; padding: 14px 18px;
62
+ margin: 6px 0; border-left: 4px solid #6366f1;
63
+ }
64
+ .candidate-pass {
65
+ background: #052e16; border-left: 4px solid #22c55e;
66
+ border-radius: 8px; padding: 10px 14px; margin: 6px 0;
67
+ }
68
+ .candidate-flag {
69
+ background: #451a03; border-left: 4px solid #f59e0b;
70
+ border-radius: 8px; padding: 10px 14px; margin: 6px 0;
71
+ }
72
+ .top-diagnosis {
73
+ background: #4a044e; border-left: 4px solid #ec4899;
74
+ border-radius: 8px; padding: 12px 16px; margin: 8px 0;
75
+ }
76
+ .hpo-chip {
77
+ display: inline-block; background: #1e1b4b; color: #a5b4fc;
78
+ border-radius: 12px; padding: 2px 10px; margin: 3px;
79
+ font-size: 0.82em;
80
+ }
81
+ .audit-step {
82
+ background: #0f172a; border-radius: 8px; padding: 12px 16px;
83
+ margin: 8px 0; border: 1px solid #334155;
84
+ }
85
+ .step-header { font-weight: 600; color: #94a3b8; font-size: 0.85em;
86
+ text-transform: uppercase; letter-spacing: 0.05em; }
87
+ </style>
88
+ """, unsafe_allow_html=True)
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # API / pipeline call
93
+ # ---------------------------------------------------------------------------
94
+
95
+ @st.cache_resource(show_spinner="Loading diagnostic pipeline...")
96
+ def get_pipeline():
97
+ """Load the DiagnosisPipeline once and cache across requests."""
98
+ from pipeline import DiagnosisPipeline
99
+ return DiagnosisPipeline()
100
+
101
+
102
+ def call_api(note: str, top_n: int = 15, threshold: float = 0.52) -> dict | None:
103
+ """
104
+ Try FastAPI first; fall back to direct pipeline call.
105
+ Returns the raw result dict or None on error.
106
+ """
107
+ try:
108
+ resp = requests.post(
109
+ f"{API_BASE}/diagnose",
110
+ json={"note": note, "top_n": top_n, "threshold": threshold},
111
+ timeout=30,
112
+ )
113
+ resp.raise_for_status()
114
+ data = resp.json()
115
+ data["_backend"] = "FastAPI"
116
+ return data
117
+ except Exception as api_err:
118
+ st.sidebar.warning(f"FastAPI not reachable ({api_err}). Running pipeline locally.")
119
+ try:
120
+ pipeline = get_pipeline()
121
+ data = pipeline.diagnose(note, top_n=top_n, threshold=threshold)
122
+ # Normalise: pipeline returns dataclass-style dicts; convert hpo_matches
123
+ data["_backend"] = "Direct (local)"
124
+ return data
125
+ except Exception as local_err:
126
+ st.error(f"Pipeline error: {local_err}")
127
+ return None
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Sidebar
132
+ # ---------------------------------------------------------------------------
133
+
134
+ with st.sidebar:
135
+ st.markdown("## RareDx")
136
+ st.caption("Rare Disease Diagnostic AI — Week 3")
137
+ st.divider()
138
+
139
+ note_input = st.text_area(
140
+ "Clinical Note",
141
+ value=DEFAULT_NOTE,
142
+ height=160,
143
+ placeholder="Describe the patient presentation...",
144
+ )
145
+
146
+ col1, col2 = st.columns(2)
147
+ top_n = col1.number_input("Top N", min_value=5, max_value=30, value=15)
148
+ threshold = col2.number_input("HPO thresh", min_value=0.4, max_value=0.9, value=0.52, step=0.01)
149
+
150
+ run_btn = st.button("Run Diagnosis", type="primary", use_container_width=True)
151
+
152
+ st.divider()
153
+
154
+ # API status
155
+ try:
156
+ hc = requests.get(f"{API_BASE}/health", timeout=2).json()
157
+ st.success(f"API: {hc['status']} | {hc.get('graph_backend','?')[:20]}")
158
+ except Exception:
159
+ st.warning("API offline — will use local pipeline")
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # Session state
164
+ # ---------------------------------------------------------------------------
165
+
166
+ if "result" not in st.session_state:
167
+ st.session_state.result = None
168
+ st.session_state.audit_log = []
169
+
170
+ if run_btn:
171
+ st.session_state.audit_log = []
172
+ t0 = time.time()
173
+
174
+ with st.spinner("Parsing symptoms and querying knowledge graph..."):
175
+ result = call_api(note_input, top_n=int(top_n), threshold=float(threshold))
176
+
177
+ if result:
178
+ result["_wall_seconds"] = round(time.time() - t0, 2)
179
+ st.session_state.result = result
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Main content
184
+ # ---------------------------------------------------------------------------
185
+
186
+ result = st.session_state.result
187
+
188
+ if result is None:
189
+ st.markdown("## Enter a clinical note and click **Run Diagnosis**")
190
+ st.markdown("""
191
+ **Example queries:**
192
+ - `18 year old male, extremely tall, displaced lens, heart murmur, flexible joints, scoliosis`
193
+ - `young woman, muscle weakness, fatigue, difficulty swallowing, drooping eyelids`
194
+ - `child with recurrent infections, absent lymph nodes, low IgG, IgA, IgM`
195
+ """)
196
+ st.stop()
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Header metrics
201
+ # ---------------------------------------------------------------------------
202
+
203
+ candidates = result.get("candidates", [])
204
+ passed_candidates = result.get("passed_candidates", [])
205
+ flagged_candidates= result.get("flagged_candidates", [])
206
+ hpo_matches = result.get("hpo_matches", [])
207
+ top = result.get("top_diagnosis") or {}
208
+
209
+ c1, c2, c3, c4, c5 = st.columns(5)
210
+ c1.metric("HPO Terms", len(hpo_matches))
211
+ c2.metric("Candidates", len(candidates))
212
+ c3.metric("Passed Guard", len(passed_candidates))
213
+ c4.metric("Flagged", len(flagged_candidates))
214
+ c5.metric("Elapsed (s)", result.get("elapsed_seconds") or result.get("_wall_seconds", "?"))
215
+
216
+ st.divider()
217
+
218
+ # ---------------------------------------------------------------------------
219
+ # Tabs
220
+ # ---------------------------------------------------------------------------
221
+
222
+ tab_results, tab_map, tab_audit = st.tabs([
223
+ "Results",
224
+ "Evidence Map",
225
+ "Agent Audit Trail",
226
+ ])
227
+
228
+
229
+ # ============================================================
230
+ # TAB 1 — RESULTS
231
+ # ============================================================
232
+
233
+ with tab_results:
234
+ # Top diagnosis banner
235
+ if top:
236
+ flag_note = " — FLAGGED (fallback)" if top.get("hallucination_flag") else ""
237
+ st.markdown(
238
+ f'<div class="top-diagnosis">'
239
+ f'<span style="color:#f9a8d4;font-size:0.8em;text-transform:uppercase;'
240
+ f'letter-spacing:0.1em">Top Diagnosis{flag_note}</span><br>'
241
+ f'<span style="font-size:1.5em;font-weight:700">{top.get("name","")}</span>'
242
+ f'&nbsp;&nbsp;<span style="color:#94a3b8">ORPHA:{top.get("orpha_code","")}</span><br>'
243
+ f'<span style="color:#94a3b8">Evidence score: {top.get("evidence_score",0):.3f}'
244
+ f' | RRF: {top.get("rrf_score",0):.5f}</span>'
245
+ f'</div>',
246
+ unsafe_allow_html=True,
247
+ )
248
+
249
+ # Evidence bar chart
250
+ st.plotly_chart(evidence_bar(candidates), use_container_width=True)
251
+
252
+ # Candidate list with Why? expanders
253
+ st.subheader("All Candidates")
254
+
255
+ for c in candidates:
256
+ is_top = str(c.get("orpha_code")) == str(top.get("orpha_code", ""))
257
+ is_flagged= c.get("hallucination_flag", False)
258
+ css_class = "top-diagnosis" if is_top else ("candidate-flag" if is_flagged else "candidate-pass")
259
+ flag_badge= " FLAGGED" if is_flagged else ""
260
+
261
+ col_left, col_right = st.columns([8, 2])
262
+
263
+ with col_left:
264
+ st.markdown(
265
+ f'<div class="{css_class}">'
266
+ f'<b>#{c["rank"]} {c["name"]}</b>{flag_badge}'
267
+ f'&nbsp;&nbsp;<span style="color:#94a3b8">ORPHA:{c["orpha_code"]}</span>'
268
+ f'</div>',
269
+ unsafe_allow_html=True,
270
+ )
271
+
272
+ with col_right:
273
+ with st.expander("Why?"):
274
+ # Graph evidence
275
+ st.markdown("**Graph phenotype evidence**")
276
+ if c.get("graph_rank"):
277
+ matched = c.get("matched_hpo", [])
278
+ if matched:
279
+ chips = " ".join(
280
+ f'<span class="hpo-chip">{h["term"]} '
281
+ f'<span style="opacity:0.6">({h.get("frequency_label","")[:4]})</span>'
282
+ f'</span>'
283
+ for h in matched
284
+ )
285
+ st.markdown(chips, unsafe_allow_html=True)
286
+ st.caption(
287
+ f"{len(matched)} / {len(hpo_matches)} query terms matched | "
288
+ f"Graph rank #{c['graph_rank']}"
289
+ )
290
+ else:
291
+ st.caption("No matched HPO terms recorded.")
292
+ else:
293
+ st.caption("Not found in graph traversal.")
294
+
295
+ st.divider()
296
+
297
+ # Vector evidence
298
+ st.markdown("**BioLORD semantic similarity**")
299
+ if c.get("chroma_rank"):
300
+ sim = c.get("chroma_sim", 0)
301
+ st.progress(float(sim), text=f"{sim:.4f} cosine similarity (vector rank #{c['chroma_rank']})")
302
+ else:
303
+ st.caption("Not found in vector search.")
304
+
305
+ st.divider()
306
+
307
+ # Hallucination guard
308
+ st.markdown("**Hallucination guard**")
309
+ if is_flagged:
310
+ st.warning(c.get("flag_reason") or "Flagged — insufficient evidence")
311
+ else:
312
+ st.success(f"Passed — evidence score {c.get('evidence_score',0):.3f}")
313
+
314
+ # Orphanet link
315
+ orpha = c.get("orpha_code", "")
316
+ st.markdown(
317
+ f"[View on Orphanet](https://www.orpha.net/en/disease/detail/{orpha})",
318
+ unsafe_allow_html=False,
319
+ )
320
+
321
+
322
+ # ============================================================
323
+ # TAB 2 — EVIDENCE MAP
324
+ # ============================================================
325
+
326
+ with tab_map:
327
+ col_map, col_right = st.columns([3, 1])
328
+
329
+ with col_map:
330
+ st.plotly_chart(evidence_map(result), use_container_width=True)
331
+
332
+ with col_right:
333
+ st.plotly_chart(
334
+ guard_donut(len(passed_candidates), len(flagged_candidates)),
335
+ use_container_width=True,
336
+ )
337
+
338
+ st.plotly_chart(rrf_waterfall(candidates), use_container_width=True)
339
+
340
+ # HPO match table
341
+ st.subheader("Extracted HPO Terms")
342
+ hpo_rows = []
343
+ for m in hpo_matches:
344
+ hpo_rows.append({
345
+ "Phrase": m.get("phrase") or m.get("phrase", ""),
346
+ "HPO ID": m.get("hpo_id") or "",
347
+ "HPO Term": m.get("term") or "",
348
+ "Confidence": round(float(m.get("score", 0)), 4),
349
+ "Type": "single word" if len((m.get("phrase") or "").split()) == 1 else "phrase",
350
+ })
351
+ if hpo_rows:
352
+ st.dataframe(hpo_rows, use_container_width=True, height=220)
353
+
354
+
355
+ # ============================================================
356
+ # TAB 3 — AGENT AUDIT TRAIL
357
+ # ============================================================
358
+
359
+ with tab_audit:
360
+ st.subheader("Pipeline Execution Trace")
361
+
362
+ backend_label = result.get("_backend", result.get("graph_backend", "local"))
363
+
364
+ steps = [
365
+ {
366
+ "step": "1. Clinical Note Input",
367
+ "color": "#6366f1",
368
+ "summary": f"{len(note_input)} characters | Backend: {backend_label}",
369
+ "detail": f'```\n{note_input}\n```',
370
+ },
371
+ {
372
+ "step": "2. Symptom Parser (BioLORD semantic HPO mapping)",
373
+ "color": "#0ea5e9",
374
+ "summary": (
375
+ f"{len(hpo_matches)} HPO terms extracted from "
376
+ f"{len(result.get('phrases_extracted', []))} candidate phrases"
377
+ ),
378
+ "detail": "\n".join(
379
+ f"- **{m.get('phrase','?')}** → `{m.get('hpo_id','')}` "
380
+ f"{m.get('term','')} (score: {m.get('score',0):.4f})"
381
+ for m in hpo_matches
382
+ ) or "_No matches_",
383
+ },
384
+ {
385
+ "step": "3. Graph Traversal (MANIFESTS_AS phenotype matching)",
386
+ "color": "#22c55e",
387
+ "summary": (
388
+ f"{sum(1 for c in candidates if c.get('graph_rank'))} candidates "
389
+ f"from {len(result.get('hpo_ids_used', []))} HPO IDs | "
390
+ f"Backend: {result.get('graph_backend','?')}"
391
+ ),
392
+ "detail": "\n".join(
393
+ f"- **#{c['graph_rank']}** {c['name']} — "
394
+ f"{c.get('graph_matches','?')} phenotype matches"
395
+ for c in sorted(
396
+ [c for c in candidates if c.get("graph_rank")],
397
+ key=lambda x: x["graph_rank"],
398
+ )
399
+ ) or "_No graph results_",
400
+ },
401
+ {
402
+ "step": "4. Semantic Search (BioLORD + ChromaDB)",
403
+ "color": "#8b5cf6",
404
+ "summary": (
405
+ f"{sum(1 for c in candidates if c.get('chroma_rank'))} candidates | "
406
+ f"Backend: {result.get('chroma_backend','?')}"
407
+ ),
408
+ "detail": "\n".join(
409
+ f"- **#{c['chroma_rank']}** {c['name']} — "
410
+ f"similarity {c.get('chroma_sim',0):.4f}"
411
+ for c in sorted(
412
+ [c for c in candidates if c.get("chroma_rank")],
413
+ key=lambda x: x["chroma_rank"],
414
+ )
415
+ ) or "_No vector results_",
416
+ },
417
+ {
418
+ "step": "5. RRF Fusion",
419
+ "color": "#ec4899",
420
+ "summary": f"{len(candidates)} unique candidates after fusion",
421
+ "detail": "\n".join(
422
+ f"- **#{c['rank']}** {c['name']} RRF={c['rrf_score']:.5f} "
423
+ f"G={'#'+str(c['graph_rank']) if c.get('graph_rank') else '—'} "
424
+ f"V={'#'+str(c['chroma_rank']) if c.get('chroma_rank') else '—'}"
425
+ for c in candidates[:10]
426
+ ),
427
+ },
428
+ {
429
+ "step": "6. FusionNode Hallucination Guard",
430
+ "color": "#f59e0b",
431
+ "summary": (
432
+ f"{len(passed_candidates)} passed, {len(flagged_candidates)} flagged "
433
+ f"(min graph matches=2, min vector sim=0.65, require frequent HPO=True)"
434
+ ),
435
+ "detail": (
436
+ "**Flagged candidates:**\n" + "\n".join(
437
+ f"- {c['name']} — {c.get('flag_reason','?')}"
438
+ for c in flagged_candidates
439
+ ) if flagged_candidates else "No candidates flagged."
440
+ ),
441
+ },
442
+ {
443
+ "step": "7. Final Output",
444
+ "color": "#ec4899",
445
+ "summary": (
446
+ f"Top diagnosis: **{top.get('name','?')}** (ORPHA:{top.get('orpha_code','?')}) | "
447
+ f"Evidence: {top.get('evidence_score',0):.3f}"
448
+ ),
449
+ "detail": (
450
+ f"- Rank: #{top.get('rank','?')}\n"
451
+ f"- RRF score: {top.get('rrf_score',0):.5f}\n"
452
+ f"- Graph rank: #{top.get('graph_rank','—')}\n"
453
+ f"- Vector rank: #{top.get('chroma_rank','—')}\n"
454
+ f"- Guard: {'PASSED' if not top.get('hallucination_flag') else 'FLAGGED (fallback)'}\n"
455
+ f"- Elapsed: {result.get('elapsed_seconds','?')}s"
456
+ ),
457
+ },
458
+ ]
459
+
460
+ for s in steps:
461
+ with st.expander(f"{s['step']} — {s['summary']}", expanded=False):
462
+ st.markdown(
463
+ f'<div style="border-left: 3px solid {s["color"]}; padding-left: 12px;">',
464
+ unsafe_allow_html=True,
465
+ )
466
+ st.markdown(s["detail"])
467
+ st.markdown("</div>", unsafe_allow_html=True)
468
+
469
+ # Raw JSON
470
+ with st.expander("Raw API Response (JSON)"):
471
+ display = {k: v for k, v in result.items() if k != "_backend"}
472
+ st.json(display)
backend/dashboard/charts.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ charts.py — Plotly chart builders for the RareDx dashboard.
3
+
4
+ Keeps all visualisation logic out of app.py.
5
+ """
6
+
7
+ from __future__ import annotations
8
+ import plotly.graph_objects as go
9
+ import math
10
+
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Colour palette
14
+ # ---------------------------------------------------------------------------
15
+ COL_PASS = "#22c55e" # green-500
16
+ COL_FLAG = "#f59e0b" # amber-500
17
+ COL_HPO = "#6366f1" # indigo-500
18
+ COL_DISEASE = "#0ea5e9" # sky-500
19
+ COL_TOP = "#ec4899" # pink-500 — top diagnosis
20
+ COL_MARFAN = "#f97316" # orange-500 — Marfan highlight
21
+ COL_EDGE = "#94a3b8" # slate-400
22
+ FREQ_COLORS = { # edge colour by frequency
23
+ "very frequent": "#22c55e",
24
+ "frequent": "#84cc16",
25
+ "occasional": "#f59e0b",
26
+ "rare": "#ef4444",
27
+ }
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Tab 1 — Evidence bar chart
32
+ # ---------------------------------------------------------------------------
33
+
34
+ def evidence_bar(candidates: list[dict]) -> go.Figure:
35
+ """
36
+ Horizontal bar chart of evidence scores, coloured by guard result.
37
+ Shows top 10 candidates.
38
+ """
39
+ top = candidates[:10]
40
+ names = [f"ORPHA:{c['orpha_code']} {c['name'][:35]}" for c in reversed(top)]
41
+ scores = [c.get("evidence_score", 0) for c in reversed(top)]
42
+ colors = [
43
+ COL_FLAG if c.get("hallucination_flag") else COL_PASS
44
+ for c in reversed(top)
45
+ ]
46
+ hover = []
47
+ for c in reversed(top):
48
+ flag_reason = c.get("flag_reason") or ""
49
+ flag_html = (
50
+ "<span style='color:orange'>FLAGGED: " + flag_reason + "</span>"
51
+ if c.get("hallucination_flag")
52
+ else "PASSED"
53
+ )
54
+ hover.append(
55
+ f"<b>{c['name']}</b><br>"
56
+ f"ORPHA:{c['orpha_code']}<br>"
57
+ f"Evidence: {c.get('evidence_score',0):.3f}<br>"
58
+ f"Graph rank: #{c.get('graph_rank','—')}<br>"
59
+ f"Vector rank: #{c.get('chroma_rank','—')}<br>"
60
+ f"Graph matches: {c.get('graph_matches','—')}<br>"
61
+ f"{flag_html}"
62
+ )
63
+
64
+ fig = go.Figure(go.Bar(
65
+ x=scores, y=names,
66
+ orientation="h",
67
+ marker_color=colors,
68
+ hovertext=hover,
69
+ hoverinfo="text",
70
+ text=[f"{s:.2f}" for s in scores],
71
+ textposition="outside",
72
+ ))
73
+ fig.update_layout(
74
+ title="Candidate Evidence Scores",
75
+ xaxis=dict(title="Evidence Score (0–1)", range=[0, 1.15]),
76
+ yaxis=dict(automargin=True),
77
+ height=420,
78
+ margin=dict(l=10, r=10, t=40, b=10),
79
+ paper_bgcolor="rgba(0,0,0,0)",
80
+ plot_bgcolor="rgba(0,0,0,0)",
81
+ font=dict(size=12),
82
+ )
83
+ return fig
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Tab 2 — Evidence map (bipartite: HPO terms → Diseases)
88
+ # ---------------------------------------------------------------------------
89
+
90
+ def evidence_map(result: dict) -> go.Figure:
91
+ """
92
+ Bipartite graph: extracted HPO terms on the left, disease candidates on
93
+ the right. Edges represent MANIFESTS_AS relationships.
94
+ Top diagnosis is highlighted in pink; flagged candidates in amber.
95
+ """
96
+ hpo_matches = result.get("hpo_matches", [])
97
+ candidates = result.get("candidates", [])[:10]
98
+ top_code = str(result.get("top_diagnosis", {}).get("orpha_code", ""))
99
+
100
+ # ---- Nodes ----
101
+ hpo_count = len(hpo_matches)
102
+ dis_count = len(candidates)
103
+
104
+ # Vertical spacing
105
+ def y_pos(i: int, total: int) -> float:
106
+ if total == 1:
107
+ return 0.5
108
+ return 1.0 - i / (total - 1)
109
+
110
+ node_x, node_y, node_text, node_color, node_size, node_hover = [], [], [], [], [], []
111
+
112
+ # HPO nodes (x=0.1)
113
+ hpo_id_to_y: dict[str, float] = {}
114
+ for i, m in enumerate(hpo_matches):
115
+ yy = y_pos(i, max(hpo_count, 1))
116
+ hpo_id_to_y[m["hpo_id"]] = yy
117
+ node_x.append(0.05)
118
+ node_y.append(yy)
119
+ node_text.append(m["term"][:30])
120
+ node_color.append(COL_HPO)
121
+ node_size.append(18)
122
+ phrase_quoted = '"' + m["phrase"] + '"'
123
+ node_hover.append(
124
+ f"<b>{m['term']}</b><br>{m['hpo_id']}<br>"
125
+ f"Score: {m['score']:.4f}<br>Phrase: {phrase_quoted}"
126
+ )
127
+
128
+ # Disease nodes (x=0.9)
129
+ dis_code_to_y: dict[str, float] = {}
130
+ for i, c in enumerate(candidates):
131
+ yy = y_pos(i, max(dis_count, 1))
132
+ dis_code_to_y[str(c["orpha_code"])] = yy
133
+ code = str(c["orpha_code"])
134
+ if code == top_code:
135
+ color = COL_TOP
136
+ size = 22
137
+ elif c.get("hallucination_flag"):
138
+ color = COL_FLAG
139
+ size = 16
140
+ else:
141
+ color = COL_DISEASE
142
+ size = 18
143
+ node_x.append(0.95)
144
+ node_y.append(yy)
145
+ node_text.append(c["name"][:28])
146
+ node_color.append(color)
147
+ node_size.append(size)
148
+ flag_note = f"<br><span style='color:orange'>FLAGGED</span>" if c.get("hallucination_flag") else ""
149
+ node_hover.append(
150
+ f"<b>{c['name']}</b><br>ORPHA:{code}<br>"
151
+ f"RRF: {c['rrf_score']:.5f}<br>Evidence: {c.get('evidence_score',0):.3f}{flag_note}"
152
+ )
153
+
154
+ # ---- Edges from matched_hpo on each candidate ----
155
+ edge_x, edge_y, edge_colors = [], [], []
156
+ for c in candidates:
157
+ code = str(c["orpha_code"])
158
+ if code not in dis_code_to_y:
159
+ continue
160
+ dy = dis_code_to_y[code]
161
+ for h in c.get("matched_hpo", []):
162
+ hid = h.get("hpo_id", "")
163
+ if hid not in hpo_id_to_y:
164
+ continue
165
+ hy = hpo_id_to_y[hid]
166
+ freq_label = h.get("frequency_label", "").lower()
167
+ ecol = next(
168
+ (v for k, v in FREQ_COLORS.items() if k in freq_label),
169
+ COL_EDGE,
170
+ )
171
+ edge_x += [0.05, 0.95, None]
172
+ edge_y += [hy, dy, None]
173
+ edge_colors.append(ecol)
174
+
175
+ fig = go.Figure()
176
+
177
+ # Draw edges (one trace per edge for colour — group by frequency colour)
178
+ from itertools import groupby
179
+ segments: list[tuple[str, list]] = []
180
+ i = 0
181
+ ex_chunks = [edge_x[j:j+3] for j in range(0, len(edge_x), 3)]
182
+ ey_chunks = [edge_y[j:j+3] for j in range(0, len(edge_y), 3)]
183
+ for col, ex, ey in zip(edge_colors, ex_chunks, ey_chunks):
184
+ fig.add_trace(go.Scatter(
185
+ x=ex, y=ey, mode="lines",
186
+ line=dict(color=col, width=1.5),
187
+ hoverinfo="none", showlegend=False,
188
+ ))
189
+
190
+ # Draw nodes
191
+ fig.add_trace(go.Scatter(
192
+ x=node_x, y=node_y,
193
+ mode="markers+text",
194
+ marker=dict(color=node_color, size=node_size, line=dict(width=1, color="#1e293b")),
195
+ text=node_text,
196
+ textposition=["middle left" if x < 0.5 else "middle right" for x in node_x],
197
+ hovertext=node_hover,
198
+ hoverinfo="text",
199
+ showlegend=False,
200
+ ))
201
+
202
+ # Column labels
203
+ fig.add_annotation(x=0.05, y=1.06, text="<b>HPO Terms</b>", showarrow=False,
204
+ font=dict(size=13, color=COL_HPO), xref="paper", yref="paper")
205
+ fig.add_annotation(x=0.95, y=1.06, text="<b>Disease Candidates</b>", showarrow=False,
206
+ font=dict(size=13, color=COL_DISEASE), xref="paper", yref="paper")
207
+
208
+ fig.update_layout(
209
+ title="Evidence Map — HPO Phenotypes to Disease Candidates",
210
+ xaxis=dict(visible=False, range=[-0.1, 1.1]),
211
+ yaxis=dict(visible=False, range=[-0.1, 1.1]),
212
+ height=520,
213
+ margin=dict(l=20, r=20, t=60, b=20),
214
+ paper_bgcolor="rgba(0,0,0,0)",
215
+ plot_bgcolor="rgba(0,0,0,0)",
216
+ )
217
+ return fig
218
+
219
+
220
+ # ---------------------------------------------------------------------------
221
+ # Tab 2 — Guard decision donut
222
+ # ---------------------------------------------------------------------------
223
+
224
+ def guard_donut(passed: int, flagged: int) -> go.Figure:
225
+ fig = go.Figure(go.Pie(
226
+ labels=["Passed", "Flagged"],
227
+ values=[passed, flagged],
228
+ hole=0.6,
229
+ marker_colors=[COL_PASS, COL_FLAG],
230
+ textinfo="label+value",
231
+ ))
232
+ fig.update_layout(
233
+ title="Hallucination Guard",
234
+ height=260,
235
+ margin=dict(l=10, r=10, t=40, b=10),
236
+ paper_bgcolor="rgba(0,0,0,0)",
237
+ showlegend=False,
238
+ )
239
+ return fig
240
+
241
+
242
+ # ---------------------------------------------------------------------------
243
+ # Tab 2 — RRF waterfall (which signal contributed what)
244
+ # ---------------------------------------------------------------------------
245
+
246
+ def rrf_waterfall(candidates: list[dict]) -> go.Figure:
247
+ top = candidates[:8]
248
+ names = [c["name"][:30] for c in top]
249
+
250
+ graph_contrib = [1/(60 + c["graph_rank"]) if c.get("graph_rank") else 0 for c in top]
251
+ chroma_contrib = [1/(60 + c["chroma_rank"]) if c.get("chroma_rank") else 0 for c in top]
252
+
253
+ fig = go.Figure()
254
+ fig.add_trace(go.Bar(name="Graph (phenotype)", x=names, y=graph_contrib,
255
+ marker_color=COL_HPO))
256
+ fig.add_trace(go.Bar(name="Vector (BioLORD)", x=names, y=chroma_contrib,
257
+ marker_color=COL_DISEASE))
258
+ fig.update_layout(
259
+ barmode="stack",
260
+ title="RRF Score Breakdown — Graph vs Vector Contribution",
261
+ yaxis_title="RRF contribution",
262
+ height=320,
263
+ margin=dict(l=10, r=10, t=40, b=80),
264
+ paper_bgcolor="rgba(0,0,0,0)",
265
+ plot_bgcolor="rgba(0,0,0,0)",
266
+ legend=dict(orientation="h", y=-0.25),
267
+ xaxis_tickangle=-30,
268
+ )
269
+ return fig
backend/reports/week4_evaluation.md ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RareDx — Week 4 Evaluation Report (RareBench-RAMEDIS)
2
+
3
+ **Generated:** 2026-03-17 20:03
4
+ **Pipeline:** DiagnosisPipeline v3.1 (BioLORD-2023 + LocalGraphStore + FusionNode)
5
+ **Evaluation set:** 30 cases sampled from [RareBench-RAMEDIS](https://huggingface.co/datasets/chenxz/RareBench) (624 total cases, 74 rare diseases)
6
+ **Case format:** HPO term names → ORPHA ground-truth code
7
+ **Source:** Feng et al. (2023), ACM KDD 2024 — real clinician-recorded phenotypes
8
+ **Threshold:** 0.50 | **Top-N:** 10
9
+
10
+ ---
11
+
12
+ ## Results
13
+
14
+ | Metric | Value | Hits / Total | Visual |
15
+ |--------|-------|-------------|--------|
16
+ | Recall@1 | **6.7%** | 2/30 | `█░░░░░░░░░░░░░░░░░░░` |
17
+ | Recall@3 | **16.7%** | 5/30 | `███░░░░░░░░░░░░░░░░░` |
18
+ | Recall@5 | **23.3%** | 7/30 | `█████░░░░░░░░░░░░░░░` |
19
+
20
+ ---
21
+
22
+ ## Benchmark Comparison
23
+
24
+ > **Comparison note:** DeepRare and baselines were evaluated on all 382–624 RAMEDIS cases using gene + variant data in addition to phenotype, giving them a significant advantage. RareDx uses phenotype-only input. This run uses 30 randomly sampled cases; results may vary vs. full-set evaluation.
25
+
26
+ > DeepRare, LIRICAL, Phrank, AMELIE, Phenomizer: Feng et al. (2023), RAMEDIS dataset (382 cases).
27
+
28
+ | System | Recall@1 | Recall@3 | Recall@5 |
29
+ |--------|----------|----------|----------|
30
+ | **RareDx (ours)** | **6.7%** | **16.7%** | **23.3%** |
31
+ | DeepRare | 37.0% | 54.0% | 62.0% |
32
+ | LIRICAL | 29.0% | 46.0% | 54.0% |
33
+ | Phrank | 22.0% | 38.0% | 47.0% |
34
+ | AMELIE | 19.0% | 33.0% | 41.0% |
35
+ | Phenomizer | 14.0% | 25.0% | 33.0% |
36
+ ### vs LIRICAL (closest phenotype-only baseline)
37
+
38
+ - Recall@1: behind by **22.3 pp** (-22.3)
39
+ - Recall@5: behind by **30.7 pp** (-30.7)
40
+
41
+ ---
42
+
43
+ ## Per-Case Results
44
+
45
+ | # | ORPHA | Disease | @1 | @3 | @5 | Rank | Top Prediction |
46
+ |---|-------|---------|----|----|----|----|----------------|
47
+ | 1 | 42 | 中链酰基辅酶 A 脱氢酶缺乏症 | | | | — | Fetal Gaucher disease |
48
+ | 2 | 27 | Vitamin B12-unresponsive methylmalo | | | | — | Malonic aciduria |
49
+ | 3 | 247598 | Neonatal intrahepatic cholestasis d | | | | — | Biotinidase deficiency |
50
+ | 4 | 89936 | X-linked hypophosphatemia | ✓ | ✓ | ✓ | 1 | X-linked hypophosphatemia |
51
+ | 5 | 67048 | 3-methylglutaconic aciduria type 4 | | | | — | X-linked neurodegenerative syn |
52
+ | 6 | 20 | 3-hydroxy-3-methylglutaric aciduria | | | ✓ | 4 | Pyruvate carboxylase deficienc |
53
+ | 7 | 79241 | 生物素酶缺乏症 | | | | — | Ichthyosis follicularis-alopec |
54
+ | 8 | 67046 | 3-methylglutaconic aciduria type 1 | | | | — | Bilateral polymicrogyria |
55
+ | 9 | 79318 | PMM2-CDG | | | | — | Alpha-mannosidosis, infantile |
56
+ | 10 | 35 | 丙酸血症 | | | | — | Argininosuccinic aciduria |
57
+ | 11 | 90 | 精氨酸酶缺乏症 | | | | 6 | Lysinuric protein intolerance |
58
+ | 12 | 664 | 鸟氨酸氨甲酰胺基转移酶缺乏症 | | ✓ | ✓ | 3 | Lysinuric protein intolerance |
59
+ | 13 | 552 | MODY | | | | — | Adenine phosphoribosyltransfer |
60
+ | 14 | 247525 | Citrullinemia type I | | | | 6 | Ornithine transcarbamylase def |
61
+ | 15 | 79242 | 全羧化酶合成酶缺乏症 | | | | — | Pyruvate carboxylase deficienc |
62
+ | 16 | 348 | Fructose-1,6-bisphosphatase deficie | | | | 8 | Pyruvate dehydrogenase E3 defi |
63
+ | 17 | 414 | Gyrate atrophy of choroid and retin | | | | — | Lysinuric protein intolerance |
64
+ | 18 | 30 | Hereditary orotic aciduria | | | | — | Brucellosis |
65
+ | 19 | 56 | Alkaptonuria | | | | — | Hurler syndrome |
66
+ | 20 | 147 | Carbamoyl-phosphate synthetase 1 de | | | | — | Hyperornithinemia-hyperammonem |
67
+ | 21 | 90791 | Congenital adrenal hyperplasia due | | ✓ | ✓ | 3 | Classic congenital adrenal hyp |
68
+ | 22 | 716 | 苯丙酮尿症 | | | | — | Lynch syndrome |
69
+ | 23 | 79282 | Methylmalonic acidemia with homocys | | ✓ | ✓ | 3 | Isolated ATP synthase deficien |
70
+ | 24 | 158 | 原发性肉碱缺乏症 | | | | — | Lysinuric protein intolerance |
71
+ | 25 | 23 | Argininosuccinic aciduria | ✓ | ✓ | ✓ | 1 | Argininosuccinic aciduria |
72
+ | 26 | 79276 | Acute intermittent porphyria | | | | — | Neurocutaneous melanocytosis |
73
+ | 27 | 33 | 异戊酸血症 | | | | — | Marburg hemorrhagic fever |
74
+ | 28 | 22 | Succinic semialdehyde dehydrogenase | | | | — | Congenital multicore myopathy |
75
+ | 29 | 699 | Pearson综合征 | | | | — | Lysinuric protein intolerance |
76
+ | 30 | 580 | Mucopolysaccharidosis type 2 | | | ✓ | 5 | Simpson-Golabi-Behmel syndrome |
77
+
78
+ ---
79
+
80
+ ### Missed Cases (not in top 5)
81
+
82
+ - **ORPHA:42** 中链酰基辅酶 A 脱氢酶缺乏症 → predicted: *Fetal Gaucher disease*
83
+ - **ORPHA:27** Vitamin B12-unresponsive methylmalonic acidemia → predicted: *Malonic aciduria*
84
+ - **ORPHA:247598** Neonatal intrahepatic cholestasis due to citrin deficiency → predicted: *Biotinidase deficiency*
85
+ - **ORPHA:67048** 3-methylglutaconic aciduria type 4 → predicted: *X-linked neurodegenerative syndrome, Hamel type*
86
+ - **ORPHA:79241** 生物素酶缺乏症 → predicted: *Ichthyosis follicularis-alopecia-photophobia syndrome*
87
+ - **ORPHA:67046** 3-methylglutaconic aciduria type 1 → predicted: *Bilateral polymicrogyria*
88
+ - **ORPHA:79318** PMM2-CDG → predicted: *Alpha-mannosidosis, infantile form*
89
+ - **ORPHA:35** 丙酸血症 → predicted: *Argininosuccinic aciduria*
90
+ - **ORPHA:90** 精氨酸酶缺乏症 → predicted: *Lysinuric protein intolerance*
91
+ - **ORPHA:552** MODY → predicted: *Adenine phosphoribosyltransferase deficiency*
92
+ - **ORPHA:247525** Citrullinemia type I → predicted: *Ornithine transcarbamylase deficiency*
93
+ - **ORPHA:79242** 全羧化酶合成酶缺乏症 → predicted: *Pyruvate carboxylase deficiency*
94
+ - **ORPHA:348** Fructose-1,6-bisphosphatase deficiency → predicted: *Pyruvate dehydrogenase E3 deficiency*
95
+ - **ORPHA:414** Gyrate atrophy of choroid and retina → predicted: *Lysinuric protein intolerance*
96
+ - **ORPHA:30** Hereditary orotic aciduria → predicted: *Brucellosis*
97
+
98
+ ---
99
+ ## Pipeline Configuration
100
+
101
+ | Component | Detail |
102
+ |-----------|--------|
103
+ | Embedding model | FremyCompany/BioLORD-2023 (768-dim) |
104
+ | HPO index | 8,701 terms |
105
+ | Graph store | LocalGraphStore — 11,456 diseases, 115,839 MANIFESTS_AS edges |
106
+ | ChromaDB | Persistent embedded (HPO-enriched embeddings) |
107
+ | Symptom parser threshold | 0.55 (multi-word), 0.82 (single-word) |
108
+ | RRF K | 60 |
109
+ | Hallucination guard | FusionNode (min_graph=2, min_sim=0.65, require_frequent=True) |
110
+
111
+ ---
112
+
113
+ ## Methodology
114
+
115
+ **RareBench-RAMEDIS methodology:**
116
+ Each case provides a list of HPO term IDs representing a real patient's documented phenotype.
117
+ Ground truth is the corresponding Orphanet disease code.
118
+
119
+ Clinical notes were built by resolving HP IDs to human-readable term names via the
120
+ RareBench phenotype mapping (https://raw.githubusercontent.com/chenxz1111/RareBench/main/mapping/phenotype_mapping.json).
121
+ The pipeline ingests these term names exactly as it would a free-text clinical note.
122
+
123
+ **Limitations:**
124
+ - 30 of 624 RAMEDIS cases used (random sample, seed=42)
125
+ - HP term names are the *only* input — no free-text narrative context
126
+ - DeepRare baselines use gene panel + phenotype; direct Recall@k comparison is indicative
127
+ - Full-set evaluation on all 624 cases is future work
128
+
129
+ ---
130
+
131
+ *Generated by week4_evaluation.py — RareDx Week 4*
backend/requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Graph DB (connects to Docker Neo4j when available)
2
+ neo4j==5.17.0
3
+
4
+ # Vector DB (embedded persistent mode, no server required)
5
+ chromadb==0.5.0
6
+
7
+ # Embeddings - BioLORD-2023
8
+ sentence-transformers==3.0.1
9
+ torch==2.3.1
10
+ transformers==4.41.2
11
+
12
+ # Data / parsing
13
+ lxml==5.2.2
14
+ requests==2.32.3
15
+ tqdm==4.66.4
16
+
17
+ # Config
18
+ python-dotenv==1.0.1
19
+
20
+ # Local graph fallback (when Neo4j Docker not available)
21
+ networkx==3.3
22
+
23
+ # API
24
+ fastapi==0.111.0
25
+ uvicorn[standard]==0.30.1
26
+
27
+ # Dashboard
28
+ streamlit==1.35.0
29
+ plotly==5.22.0
backend/scripts/__pycache__/graph_store.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
backend/scripts/__pycache__/symptom_parser.cpython-310.pyc ADDED
Binary file (7.22 kB). View file
 
backend/scripts/download_hpo.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ download_hpo.py
3
+ ---------------
4
+ Downloads Orphanet product 4 XML: disease-to-HPO phenotype associations.
5
+
6
+ en_product4.xml maps each OrphaCode to HPO terms with:
7
+ - HPO ID (e.g. HP:0001166)
8
+ - HPO term name (e.g. "Arachnodactyly")
9
+ - Frequency (Very frequent / Frequent / Occasional / Rare / Excluded)
10
+ - Diagnostic criteria (Pathognomonic / Diagnostic / Major / Minor)
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import requests
16
+ from pathlib import Path
17
+ from dotenv import load_dotenv
18
+
19
+ load_dotenv(Path(__file__).parents[2] / ".env")
20
+
21
+ DATA_DIR = Path(os.getenv("ORPHANET_DATA_DIR", "./data/orphanet"))
22
+ OUTPUT_FILE = DATA_DIR / "en_product4.xml"
23
+
24
+ URLS = [
25
+ "https://www.orphadata.com/data/xml/en_product4.xml",
26
+ "http://www.orphadata.org/data/xml/en_product4.xml",
27
+ ]
28
+
29
+
30
+ def download(urls: list[str], output: Path, timeout: int = 120) -> bool:
31
+ headers = {"User-Agent": "RareDxBot/1.0"}
32
+ for url in urls:
33
+ print(f" Trying: {url}")
34
+ try:
35
+ r = requests.get(url, headers=headers, timeout=timeout, stream=True)
36
+ r.raise_for_status()
37
+ total = int(r.headers.get("content-length", 0))
38
+ downloaded = 0
39
+ with open(output, "wb") as f:
40
+ for chunk in r.iter_content(65536):
41
+ f.write(chunk)
42
+ downloaded += len(chunk)
43
+ if total:
44
+ print(f"\r {downloaded:,} / {total:,} bytes ({downloaded/total*100:.1f}%)", end="")
45
+ print(f"\n Saved: {output} ({output.stat().st_size:,} bytes)")
46
+ return True
47
+ except Exception as e:
48
+ print(f" Failed: {e}")
49
+ return False
50
+
51
+
52
+ def main() -> None:
53
+ print("=" * 60)
54
+ print("RareDx — Week 2A: Download HPO Phenotype Data")
55
+ print("=" * 60)
56
+
57
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
58
+
59
+ if OUTPUT_FILE.exists() and OUTPUT_FILE.stat().st_size > 10_000:
60
+ print(f"Already exists: {OUTPUT_FILE} ({OUTPUT_FILE.stat().st_size:,} bytes). Skipping.")
61
+ return
62
+
63
+ print("\nDownloading en_product4.xml (HPO associations)...")
64
+ if not download(URLS, OUTPUT_FILE):
65
+ print("ERROR: Could not download en_product4.xml.")
66
+ sys.exit(1)
67
+
68
+ content = OUTPUT_FILE.read_text(encoding="utf-8")
69
+ count = content.count("<HPOId>")
70
+ print(f"Validation OK — {count:,} HPO associations found.")
71
+ print("\nStep done.")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
backend/scripts/download_orphanet.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ download_orphanet.py
3
+ -------------------
4
+ Downloads Orphanet product 1 XML (rare disease names + synonyms).
5
+ Falls back to embedded sample data if the remote file is unreachable.
6
+
7
+ Orphadata en_product1.xml contains:
8
+ - OrphaCode
9
+ - Disease name (English)
10
+ - Synonyms
11
+ - ExternalReference (OMIM, UMLS, etc.)
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import requests
17
+ from pathlib import Path
18
+ from dotenv import load_dotenv
19
+
20
+ load_dotenv(Path(__file__).parents[2] / ".env")
21
+
22
+ DATA_DIR = Path(os.getenv("ORPHANET_DATA_DIR", "./data/orphanet"))
23
+ OUTPUT_FILE = DATA_DIR / "en_product1.xml"
24
+
25
+ # Orphadata direct download URL (public, no auth required)
26
+ ORPHADATA_URL = (
27
+ "https://www.orphadata.com/data/xml/en_product1.xml"
28
+ )
29
+
30
+ FALLBACK_URLS = [
31
+ "https://www.orphadata.com/data/xml/en_product1.xml",
32
+ "http://www.orphadata.org/data/xml/en_product1.xml",
33
+ ]
34
+
35
+ SAMPLE_XML = """\
36
+ <?xml version="1.0" encoding="UTF-8"?>
37
+ <JDBOR date="2024-01-01" version="1.3.26.5" copyright="Orphanet (c) 2024">
38
+ <DisorderList count="10">
39
+ <Disorder id="166">
40
+ <OrphaCode>166</OrphaCode>
41
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=166</ExpertLink>
42
+ <Name lang="en">Marfan syndrome</Name>
43
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
44
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
45
+ <TextAuto lang="en">Marfan syndrome is a systemic connective tissue disorder caused by mutations in the FBN1 gene encoding fibrillin-1. It affects the cardiovascular system (aortic dilatation, mitral valve prolapse), skeleton (tall stature, arachnodactyly, scoliosis), and ocular system (ectopia lentis).</TextAuto>
46
+ <SynonymList count="2">
47
+ <Synonym lang="en">MFS</Synonym>
48
+ <Synonym lang="en">Marfan's syndrome</Synonym>
49
+ </SynonymList>
50
+ </Disorder>
51
+ <Disorder id="79318">
52
+ <OrphaCode>79318</OrphaCode>
53
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=79318</ExpertLink>
54
+ <Name lang="en">Ehlers-Danlos syndrome, hypermobile type</Name>
55
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
56
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
57
+ <TextAuto lang="en">Hypermobile Ehlers-Danlos syndrome (hEDS) is a connective tissue disorder characterized by joint hypermobility, skin hyperextensibility, and musculoskeletal fragility. It is the most common form of EDS and lacks a known causative gene.</TextAuto>
58
+ <SynonymList count="2">
59
+ <Synonym lang="en">Hypermobile EDS</Synonym>
60
+ <Synonym lang="en">hEDS</Synonym>
61
+ </SynonymList>
62
+ </Disorder>
63
+ <Disorder id="93">
64
+ <OrphaCode>93</OrphaCode>
65
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=93</ExpertLink>
66
+ <Name lang="en">Wilson disease</Name>
67
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
68
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
69
+ <TextAuto lang="en">Wilson disease is an autosomal recessive disorder of copper metabolism caused by mutations in the ATP7B gene. It leads to copper accumulation in the liver, brain, and other organs, causing hepatic cirrhosis, neuropsychiatric symptoms, and Kayser-Fleischer rings.</TextAuto>
70
+ <SynonymList count="2">
71
+ <Synonym lang="en">Hepatolenticular degeneration</Synonym>
72
+ <Synonym lang="en">WD</Synonym>
73
+ </SynonymList>
74
+ </Disorder>
75
+ <Disorder id="2552">
76
+ <OrphaCode>2552</OrphaCode>
77
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=2552</ExpertLink>
78
+ <Name lang="en">Huntington disease</Name>
79
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
80
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
81
+ <TextAuto lang="en">Huntington disease is an autosomal dominant neurodegenerative disorder caused by a CAG repeat expansion in the HTT gene. It manifests as progressive chorea, cognitive decline, and psychiatric disturbances, with adult onset typically between 30-50 years.</TextAuto>
82
+ <SynonymList count="2">
83
+ <Synonym lang="en">HD</Synonym>
84
+ <Synonym lang="en">Huntington's chorea</Synonym>
85
+ </SynonymList>
86
+ </Disorder>
87
+ <Disorder id="586">
88
+ <OrphaCode>586</OrphaCode>
89
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=586</ExpertLink>
90
+ <Name lang="en">Cystic fibrosis</Name>
91
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
92
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
93
+ <TextAuto lang="en">Cystic fibrosis is an autosomal recessive multisystem disorder caused by mutations in the CFTR gene, leading to defective chloride transport. It primarily affects the lungs (progressive obstructive lung disease), pancreas (exocrine insufficiency), and reproductive system.</TextAuto>
94
+ <SynonymList count="2">
95
+ <Synonym lang="en">CF</Synonym>
96
+ <Synonym lang="en">Mucoviscidosis</Synonym>
97
+ </SynonymList>
98
+ </Disorder>
99
+ <Disorder id="774">
100
+ <OrphaCode>774</OrphaCode>
101
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=774</ExpertLink>
102
+ <Name lang="en">Phenylketonuria</Name>
103
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
104
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
105
+ <TextAuto lang="en">Phenylketonuria (PKU) is an autosomal recessive inborn error of phenylalanine metabolism caused by mutations in the PAH gene encoding phenylalanine hydroxylase. It results in hyperphenylalaninemia leading to intellectual disability if untreated.</TextAuto>
106
+ <SynonymList count="1">
107
+ <Synonym lang="en">PKU</Synonym>
108
+ </SynonymList>
109
+ </Disorder>
110
+ <Disorder id="778">
111
+ <OrphaCode>778</OrphaCode>
112
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=778</ExpertLink>
113
+ <Name lang="en">Pompe disease</Name>
114
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
115
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
116
+ <TextAuto lang="en">Pompe disease (glycogen storage disease type II) is an autosomal recessive lysosomal storage disorder caused by deficiency of acid alpha-glucosidase (GAA). It causes progressive muscle weakness, respiratory failure, and cardiomyopathy (in infantile form).</TextAuto>
117
+ <SynonymList count="2">
118
+ <Synonym lang="en">Glycogen storage disease type 2</Synonym>
119
+ <Synonym lang="en">Acid maltase deficiency</Synonym>
120
+ </SynonymList>
121
+ </Disorder>
122
+ <Disorder id="823">
123
+ <OrphaCode>823</OrphaCode>
124
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=823</ExpertLink>
125
+ <Name lang="en">Tuberous sclerosis complex</Name>
126
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
127
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
128
+ <TextAuto lang="en">Tuberous sclerosis complex (TSC) is an autosomal dominant multisystem disorder caused by mutations in TSC1 or TSC2 genes, leading to mTOR pathway overactivation. It causes benign tumors in multiple organs including brain (cortical tubers, subependymal nodules), skin, kidneys, and lungs.</TextAuto>
129
+ <SynonymList count="2">
130
+ <Synonym lang="en">TSC</Synonym>
131
+ <Synonym lang="en">Bourneville disease</Synonym>
132
+ </SynonymList>
133
+ </Disorder>
134
+ <Disorder id="699">
135
+ <OrphaCode>699</OrphaCode>
136
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=699</ExpertLink>
137
+ <Name lang="en">Fabry disease</Name>
138
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
139
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
140
+ <TextAuto lang="en">Fabry disease is an X-linked lysosomal storage disorder caused by deficiency of alpha-galactosidase A (GLA gene), leading to accumulation of globotriaosylceramide. It causes neuropathic pain, angiokeratomas, cardiomyopathy, renal failure, and stroke.</TextAuto>
141
+ <SynonymList count="2">
142
+ <Synonym lang="en">Anderson-Fabry disease</Synonym>
143
+ <Synonym lang="en">Alpha-galactosidase A deficiency</Synonym>
144
+ </SynonymList>
145
+ </Disorder>
146
+ <Disorder id="101">
147
+ <OrphaCode>101</OrphaCode>
148
+ <ExpertLink lang="en">http://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&amp;Expert=101</ExpertLink>
149
+ <Name lang="en">Achondroplasia</Name>
150
+ <DisorderType id="21394"><Name lang="en">Disease</Name></DisorderType>
151
+ <DisorderGroup id="36547"><Name lang="en">Disorder</Name></DisorderGroup>
152
+ <TextAuto lang="en">Achondroplasia is the most common form of skeletal dysplasia, caused by gain-of-function mutations in the FGFR3 gene. It is characterized by short stature, rhizomelic limb shortening, macrocephaly with midface hypoplasia, and normal intelligence.</TextAuto>
153
+ <SynonymList count="0">
154
+ </SynonymList>
155
+ </Disorder>
156
+ </DisorderList>
157
+ </JDBOR>
158
+ """
159
+
160
+
161
+ def download_with_retry(urls: list[str], output_path: Path, timeout: int = 60) -> bool:
162
+ """Try each URL in order, return True if any succeeds."""
163
+ headers = {
164
+ "User-Agent": (
165
+ "Mozilla/5.0 (compatible; RareDxBot/1.0; "
166
+ "+https://github.com/rare-dx)"
167
+ )
168
+ }
169
+ for url in urls:
170
+ print(f" Trying: {url}")
171
+ try:
172
+ response = requests.get(url, headers=headers, timeout=timeout, stream=True)
173
+ response.raise_for_status()
174
+
175
+ total = int(response.headers.get("content-length", 0))
176
+ downloaded = 0
177
+
178
+ with open(output_path, "wb") as f:
179
+ for chunk in response.iter_content(chunk_size=65536):
180
+ f.write(chunk)
181
+ downloaded += len(chunk)
182
+ if total:
183
+ pct = downloaded / total * 100
184
+ print(f"\r Downloaded: {downloaded:,} / {total:,} bytes ({pct:.1f}%)", end="")
185
+
186
+ print(f"\n Saved to {output_path} ({output_path.stat().st_size:,} bytes)")
187
+ return True
188
+
189
+ except Exception as exc:
190
+ print(f" Failed: {exc}")
191
+
192
+ return False
193
+
194
+
195
+ def write_sample_data(output_path: Path) -> None:
196
+ """Write embedded sample Orphanet data for offline/dev use."""
197
+ print(" Using embedded sample data (10 diseases).")
198
+ output_path.write_text(SAMPLE_XML, encoding="utf-8")
199
+ print(f" Saved to {output_path} ({output_path.stat().st_size:,} bytes)")
200
+
201
+
202
+ def main() -> None:
203
+ print("=" * 60)
204
+ print("RareDx — Step 1: Download Orphanet Data")
205
+ print("=" * 60)
206
+
207
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
208
+
209
+ if OUTPUT_FILE.exists() and OUTPUT_FILE.stat().st_size > 1000:
210
+ print(f"Already exists: {OUTPUT_FILE} ({OUTPUT_FILE.stat().st_size:,} bytes). Skipping download.")
211
+ return
212
+
213
+ print("\nAttempting to download en_product1.xml from Orphadata...")
214
+ success = download_with_retry(FALLBACK_URLS, OUTPUT_FILE)
215
+
216
+ if not success:
217
+ print("\nRemote download failed. Writing sample data for development.")
218
+ write_sample_data(OUTPUT_FILE)
219
+
220
+ # Quick validation
221
+ content = OUTPUT_FILE.read_text(encoding="utf-8")
222
+ if "<OrphaCode>" not in content:
223
+ print("ERROR: Downloaded file does not appear to be valid Orphanet XML.")
224
+ sys.exit(1)
225
+
226
+ disease_count = content.count("<OrphaCode>")
227
+ print(f"\nValidation OK — found {disease_count} disease entries.")
228
+ print("\nStep 1 complete.")
229
+
230
+
231
+ if __name__ == "__main__":
232
+ main()
backend/scripts/embed_chromadb.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ embed_chromadb.py
3
+ -----------------
4
+ Generates BioLORD-2023 embeddings for each Orphanet disease and stores
5
+ them in ChromaDB.
6
+
7
+ Primary: ChromaDB HTTP client (Docker service at localhost:8000)
8
+ Fallback: ChromaDB PersistentClient (embedded, no server required)
9
+
10
+ Embedding text strategy:
11
+ "<name>. <definition>. Also known as: <syn1>, <syn2>, ..."
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ from pathlib import Path
17
+ from lxml import etree
18
+ import chromadb
19
+ from chromadb.config import Settings
20
+ from sentence_transformers import SentenceTransformer
21
+ from dotenv import load_dotenv
22
+
23
+ load_dotenv(Path(__file__).parents[2] / ".env")
24
+
25
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
26
+ CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
27
+ COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
28
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
29
+ XML_PATH = Path(os.getenv("ORPHANET_XML", "./data/orphanet/en_product1.xml"))
30
+
31
+ CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"
32
+ BATCH_SIZE = 32
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # XML parsing
37
+ # ---------------------------------------------------------------------------
38
+
39
+ def _text(element, xpath: str) -> str:
40
+ nodes = element.xpath(xpath)
41
+ if nodes:
42
+ val = nodes[0]
43
+ return (val.text or "").strip() if hasattr(val, "text") else str(val).strip()
44
+ return ""
45
+
46
+
47
+ def parse_disorders(xml_path: Path) -> list[dict]:
48
+ print(f"Parsing {xml_path} ...")
49
+ tree = etree.parse(str(xml_path))
50
+ root = tree.getroot()
51
+ disorders = []
52
+ for disorder in root.xpath("//Disorder"):
53
+ orpha_code = _text(disorder, "OrphaCode")
54
+ name = _text(disorder, "Name[@lang='en']")
55
+ definition = _text(disorder, "TextAuto[@lang='en']")
56
+ synonyms = [
57
+ s.text.strip()
58
+ for s in disorder.xpath("SynonymList/Synonym[@lang='en']")
59
+ if s.text and s.text.strip()
60
+ ]
61
+ if not orpha_code or not name:
62
+ continue
63
+
64
+ parts = [name]
65
+ if definition:
66
+ parts.append(definition)
67
+ if synonyms:
68
+ parts.append(f"Also known as: {', '.join(synonyms)}.")
69
+ embed_text = " ".join(parts)
70
+
71
+ disorders.append({
72
+ "id": f"ORPHA:{orpha_code}",
73
+ "orpha_code": orpha_code,
74
+ "name": name,
75
+ "definition": definition,
76
+ "synonyms": synonyms,
77
+ "embed_text": embed_text,
78
+ })
79
+
80
+ print(f" Parsed {len(disorders)} disorders.")
81
+ return disorders
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # ChromaDB client — HTTP first, persistent fallback
86
+ # ---------------------------------------------------------------------------
87
+
88
+ def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
89
+ """
90
+ Try HTTP client (Docker). On failure, fall back to embedded PersistentClient.
91
+ Returns (client, backend_label).
92
+ """
93
+ try:
94
+ client = chromadb.HttpClient(
95
+ host=CHROMA_HOST,
96
+ port=CHROMA_PORT,
97
+ settings=Settings(anonymized_telemetry=False),
98
+ )
99
+ client.heartbeat()
100
+ print(" ChromaDB HTTP server connected.")
101
+ return client, "ChromaDB HTTP (Docker)"
102
+ except Exception as exc:
103
+ print(f" ChromaDB HTTP not reachable ({exc}).")
104
+ print(f" Using embedded PersistentClient at {CHROMA_PERSIST_DIR}")
105
+ CHROMA_PERSIST_DIR.mkdir(parents=True, exist_ok=True)
106
+ client = chromadb.PersistentClient(
107
+ path=str(CHROMA_PERSIST_DIR),
108
+ settings=Settings(anonymized_telemetry=False),
109
+ )
110
+ return client, "ChromaDB Embedded (local)"
111
+
112
+
113
+ def get_or_create_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
114
+ try:
115
+ client.delete_collection(name)
116
+ print(f" Deleted existing collection '{name}'.")
117
+ except Exception:
118
+ pass
119
+ collection = client.create_collection(
120
+ name=name,
121
+ metadata={"hnsw:space": "cosine"},
122
+ )
123
+ print(f" Created collection '{name}'.")
124
+ return collection
125
+
126
+
127
+ def upsert_in_batches(
128
+ collection: chromadb.Collection,
129
+ disorders: list[dict],
130
+ embeddings: list[list[float]],
131
+ ) -> None:
132
+ for i in range(0, len(disorders), BATCH_SIZE):
133
+ bd = disorders[i : i + BATCH_SIZE]
134
+ be = embeddings[i : i + BATCH_SIZE]
135
+ collection.upsert(
136
+ ids=[d["id"] for d in bd],
137
+ embeddings=be,
138
+ documents=[d["embed_text"] for d in bd],
139
+ metadatas=[
140
+ {
141
+ "orpha_code": d["orpha_code"],
142
+ "name": d["name"],
143
+ "definition": d["definition"][:500] if d["definition"] else "",
144
+ "synonyms": ", ".join(d["synonyms"]),
145
+ }
146
+ for d in bd
147
+ ],
148
+ )
149
+ print(f" Upserted {min(i + BATCH_SIZE, len(disorders))} / {len(disorders)} ...", end="\r")
150
+ print()
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Main
155
+ # ---------------------------------------------------------------------------
156
+
157
+ def main() -> None:
158
+ print("=" * 60)
159
+ print("RareDx — Step 3: Embed Diseases into ChromaDB (BioLORD-2023)")
160
+ print("=" * 60)
161
+
162
+ if not XML_PATH.exists():
163
+ print(f"ERROR: XML not found at {XML_PATH}. Run download_orphanet.py first.")
164
+ sys.exit(1)
165
+
166
+ disorders = parse_disorders(XML_PATH)
167
+
168
+ # Load BioLORD-2023
169
+ print(f"\nLoading embedding model: {EMBED_MODEL}")
170
+ print(" (First run will download ~440 MB from HuggingFace — please wait.)")
171
+ model = SentenceTransformer(EMBED_MODEL)
172
+ dim = model.get_sentence_embedding_dimension()
173
+ print(f" Model loaded. Embedding dim: {dim}")
174
+
175
+ # Generate embeddings
176
+ print(f"\nGenerating embeddings for {len(disorders)} diseases...")
177
+ texts = [d["embed_text"] for d in disorders]
178
+ embeddings = model.encode(
179
+ texts,
180
+ batch_size=BATCH_SIZE,
181
+ show_progress_bar=True,
182
+ normalize_embeddings=True,
183
+ )
184
+ print(f" Embeddings shape: {embeddings.shape}")
185
+
186
+ # Connect to ChromaDB
187
+ print("\nConnecting to ChromaDB...")
188
+ chroma, backend_label = get_chroma_client()
189
+ collection = get_or_create_collection(chroma, COLLECTION_NAME)
190
+
191
+ print(f"\nUpserting {len(disorders)} documents...")
192
+ upsert_in_batches(collection, disorders, embeddings.tolist())
193
+
194
+ final_count = collection.count()
195
+ print(f" Collection '{COLLECTION_NAME}' has {final_count} documents.")
196
+
197
+ # Sanity check
198
+ print("\nSanity check: semantic search for 'connective tissue disorder'")
199
+ probe = model.encode(["connective tissue disorder"], normalize_embeddings=True)
200
+ results = collection.query(query_embeddings=probe.tolist(), n_results=3)
201
+ for meta in results["metadatas"][0]:
202
+ print(f" -> [{meta['orpha_code']}] {meta['name']}")
203
+
204
+ print(f"\nStep 3 complete — backend: {backend_label}")
205
+
206
+
207
+ if __name__ == "__main__":
208
+ main()
backend/scripts/graph_store.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ graph_store.py
3
+ --------------
4
+ Lightweight local graph store that mirrors the Neo4j schema used by RareDx.
5
+ Uses NetworkX in-memory + JSON persistence as a drop-in fallback when
6
+ the Neo4j Docker service is unavailable.
7
+
8
+ Graph schema:
9
+ (:Disease {orpha_code, name, definition, expert_link})
10
+ (:Synonym {text})
11
+ (:HPOTerm {hpo_id, term})
12
+ (:Disease)-[:HAS_SYNONYM]->(:Synonym)
13
+ (:Disease)-[:MANIFESTS_AS {frequency, frequency_label, diagnostic_criteria}]->(:HPOTerm)
14
+ """
15
+
16
+ import json
17
+ from pathlib import Path
18
+ from typing import Optional
19
+
20
+ import networkx as nx
21
+
22
+ DEFAULT_PATH = Path(__file__).parents[2] / "data" / "graph_store.json"
23
+
24
+
25
+ class LocalGraphStore:
26
+ """NetworkX-backed graph store with JSON persistence."""
27
+
28
+ def __init__(self, path: Path = DEFAULT_PATH) -> None:
29
+ self.path = path
30
+ self.graph = nx.DiGraph()
31
+ if path.exists():
32
+ self._load()
33
+
34
+ # ------------------------------------------------------------------
35
+ # Persistence
36
+ # ------------------------------------------------------------------
37
+
38
+ def _load(self) -> None:
39
+ data = json.loads(self.path.read_text(encoding="utf-8"))
40
+ for node in data.get("nodes", []):
41
+ nid = node.pop("id")
42
+ self.graph.add_node(nid, **node)
43
+ for edge in data.get("edges", []):
44
+ attrs = {k: v for k, v in edge.items() if k not in ("src", "dst")}
45
+ self.graph.add_edge(edge["src"], edge["dst"], **attrs)
46
+
47
+ def save(self) -> None:
48
+ self.path.parent.mkdir(parents=True, exist_ok=True)
49
+ data = {
50
+ "nodes": [{"id": n, **self.graph.nodes[n]} for n in self.graph.nodes],
51
+ "edges": [
52
+ {"src": u, "dst": v, **d}
53
+ for u, v, d in self.graph.edges(data=True)
54
+ ],
55
+ }
56
+ self.path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
57
+
58
+ # ------------------------------------------------------------------
59
+ # Disease + Synonym write
60
+ # ------------------------------------------------------------------
61
+
62
+ def upsert_disease(self, orpha_code: int, name: str, definition: str, expert_link: str) -> None:
63
+ nid = f"Disease:{orpha_code}"
64
+ self.graph.add_node(
65
+ nid,
66
+ type="Disease",
67
+ orpha_code=orpha_code,
68
+ name=name,
69
+ definition=definition,
70
+ expert_link=expert_link,
71
+ )
72
+
73
+ def add_synonym(self, orpha_code: int, synonym_text: str) -> None:
74
+ disease_nid = f"Disease:{orpha_code}"
75
+ syn_nid = f"Synonym:{synonym_text}"
76
+ self.graph.add_node(syn_nid, type="Synonym", text=synonym_text)
77
+ self.graph.add_edge(disease_nid, syn_nid, label="HAS_SYNONYM")
78
+
79
+ def upsert_disorders_bulk(self, disorders: list[dict]) -> int:
80
+ for d in disorders:
81
+ self.upsert_disease(
82
+ orpha_code=d["orpha_code"],
83
+ name=d["name"],
84
+ definition=d.get("definition", ""),
85
+ expert_link=d.get("expert_link", ""),
86
+ )
87
+ for syn in d.get("synonyms", []):
88
+ self.add_synonym(d["orpha_code"], syn)
89
+ self.save()
90
+ return len(disorders)
91
+
92
+ # ------------------------------------------------------------------
93
+ # HPO write
94
+ # ------------------------------------------------------------------
95
+
96
+ def upsert_hpo_term(self, hpo_id: str, term: str) -> None:
97
+ """Create or update an HPOTerm node."""
98
+ nid = f"HPO:{hpo_id}"
99
+ self.graph.add_node(nid, type="HPOTerm", hpo_id=hpo_id, term=term)
100
+
101
+ def add_manifestation(
102
+ self,
103
+ orpha_code: int,
104
+ hpo_id: str,
105
+ frequency_label: str,
106
+ frequency_order: int,
107
+ diagnostic_criteria: str,
108
+ ) -> None:
109
+ """
110
+ Add (:Disease)-[:MANIFESTS_AS {frequency_label, frequency_order, diagnostic_criteria}]->(:HPOTerm)
111
+ frequency_order: 1=Very frequent, 2=Frequent, 3=Occasional, 4=Rare, 5=Excluded, 0=Unknown
112
+ """
113
+ disease_nid = f"Disease:{orpha_code}"
114
+ hpo_nid = f"HPO:{hpo_id}"
115
+ if disease_nid not in self.graph:
116
+ return # skip if disease not loaded yet
117
+ self.graph.add_edge(
118
+ disease_nid,
119
+ hpo_nid,
120
+ label="MANIFESTS_AS",
121
+ frequency_label=frequency_label,
122
+ frequency_order=frequency_order,
123
+ diagnostic_criteria=diagnostic_criteria,
124
+ )
125
+
126
+ def upsert_hpo_bulk(self, associations: list[dict]) -> int:
127
+ """
128
+ associations: list of {orpha_code, hpo_id, term, frequency_label,
129
+ frequency_order, diagnostic_criteria}
130
+ """
131
+ for a in associations:
132
+ self.upsert_hpo_term(a["hpo_id"], a["term"])
133
+ self.add_manifestation(
134
+ orpha_code=a["orpha_code"],
135
+ hpo_id=a["hpo_id"],
136
+ frequency_label=a["frequency_label"],
137
+ frequency_order=a["frequency_order"],
138
+ diagnostic_criteria=a["diagnostic_criteria"],
139
+ )
140
+ self.save()
141
+ return len(associations)
142
+
143
+ # ------------------------------------------------------------------
144
+ # Disease read
145
+ # ------------------------------------------------------------------
146
+
147
+ def find_disease_by_name(self, name_fragment: str) -> Optional[dict]:
148
+ """Case-insensitive contains search."""
149
+ fragment = name_fragment.lower()
150
+ for nid, attrs in self.graph.nodes(data=True):
151
+ if attrs.get("type") == "Disease":
152
+ if fragment in attrs.get("name", "").lower():
153
+ return self._hydrate_disease(nid, attrs)
154
+ return None
155
+
156
+ def get_disease_by_orpha(self, orpha_code: int) -> Optional[dict]:
157
+ nid = f"Disease:{orpha_code}"
158
+ if nid in self.graph:
159
+ return self._hydrate_disease(nid, self.graph.nodes[nid])
160
+ return None
161
+
162
+ def _hydrate_disease(self, nid: str, attrs: dict) -> dict:
163
+ synonyms, hpo_terms = [], []
164
+ for v, edge_data in self.graph[nid].items():
165
+ vtype = self.graph.nodes[v].get("type")
166
+ if vtype == "Synonym":
167
+ synonyms.append(self.graph.nodes[v]["text"])
168
+ elif vtype == "HPOTerm":
169
+ hpo_terms.append({
170
+ "hpo_id": self.graph.nodes[v]["hpo_id"],
171
+ "term": self.graph.nodes[v]["term"],
172
+ "frequency_label": edge_data.get("frequency_label", ""),
173
+ "frequency_order": edge_data.get("frequency_order", 0),
174
+ "diagnostic_criteria": edge_data.get("diagnostic_criteria", ""),
175
+ })
176
+ hpo_terms.sort(key=lambda x: x["frequency_order"])
177
+ return {
178
+ "orpha_code": attrs["orpha_code"],
179
+ "name": attrs["name"],
180
+ "definition": attrs.get("definition", ""),
181
+ "expert_link": attrs.get("expert_link", ""),
182
+ "synonyms": synonyms,
183
+ "hpo_terms": hpo_terms,
184
+ }
185
+
186
+ # ------------------------------------------------------------------
187
+ # Phenotype-based diagnostic query
188
+ # ------------------------------------------------------------------
189
+
190
+ def find_diseases_by_hpo(
191
+ self,
192
+ hpo_ids: list[str],
193
+ top_n: int = 10,
194
+ min_matches: int = 1,
195
+ ) -> list[dict]:
196
+ """
197
+ Given a list of HPO term IDs, find diseases that manifest those phenotypes.
198
+ Returns diseases ranked by:
199
+ 1. Number of matching HPO terms (desc)
200
+ 2. Sum of frequency weights of matched terms (desc)
201
+ (Very frequent=5, Frequent=4, Occasional=3, Rare=2, Excluded=-1, Unknown=1)
202
+
203
+ This is the core graph-based differential diagnosis query.
204
+ """
205
+ FREQ_WEIGHT = {1: 5, 2: 4, 3: 3, 4: 2, 5: -1, 0: 1}
206
+
207
+ query_nodes = {f"HPO:{hid}" for hid in hpo_ids}
208
+
209
+ # Walk from each HPO node to Disease predecessors
210
+ disease_scores: dict[str, dict] = {}
211
+ for hpo_nid in query_nodes:
212
+ if hpo_nid not in self.graph:
213
+ continue
214
+ for disease_nid in self.graph.predecessors(hpo_nid):
215
+ if self.graph.nodes[disease_nid].get("type") != "Disease":
216
+ continue
217
+ edge = self.graph[disease_nid][hpo_nid]
218
+ if edge.get("label") != "MANIFESTS_AS":
219
+ continue
220
+ # Skip excluded phenotypes
221
+ if edge.get("frequency_order") == 5:
222
+ continue
223
+
224
+ freq_w = FREQ_WEIGHT.get(edge.get("frequency_order", 0), 1)
225
+ if disease_nid not in disease_scores:
226
+ disease_scores[disease_nid] = {
227
+ "match_count": 0,
228
+ "freq_score": 0.0,
229
+ "matched_hpo": [],
230
+ }
231
+ disease_scores[disease_nid]["match_count"] += 1
232
+ disease_scores[disease_nid]["freq_score"] += freq_w
233
+ disease_scores[disease_nid]["matched_hpo"].append({
234
+ "hpo_id": self.graph.nodes[hpo_nid]["hpo_id"],
235
+ "term": self.graph.nodes[hpo_nid]["term"],
236
+ "frequency_label": edge.get("frequency_label", ""),
237
+ })
238
+
239
+ # Filter minimum matches and rank
240
+ ranked = [
241
+ (nid, scores)
242
+ for nid, scores in disease_scores.items()
243
+ if scores["match_count"] >= min_matches
244
+ ]
245
+ ranked.sort(key=lambda x: (x[1]["match_count"], x[1]["freq_score"]), reverse=True)
246
+
247
+ results = []
248
+ for disease_nid, scores in ranked[:top_n]:
249
+ attrs = self.graph.nodes[disease_nid]
250
+ results.append({
251
+ "orpha_code": attrs["orpha_code"],
252
+ "name": attrs["name"],
253
+ "definition": attrs.get("definition", ""),
254
+ "match_count": scores["match_count"],
255
+ "total_query_terms": len(hpo_ids),
256
+ "freq_score": round(scores["freq_score"], 2),
257
+ "matched_hpo": scores["matched_hpo"],
258
+ })
259
+ return results
260
+
261
+ def find_diseases_by_hpo_terms(
262
+ self,
263
+ term_names: list[str],
264
+ top_n: int = 10,
265
+ ) -> list[dict]:
266
+ """
267
+ Convenience wrapper: search by HPO term names (case-insensitive)
268
+ instead of HPO IDs.
269
+ """
270
+ hpo_ids = []
271
+ for name in term_names:
272
+ name_lower = name.lower()
273
+ for nid, attrs in self.graph.nodes(data=True):
274
+ if attrs.get("type") == "HPOTerm":
275
+ if name_lower in attrs.get("term", "").lower():
276
+ hpo_ids.append(attrs["hpo_id"])
277
+ break
278
+ return self.find_diseases_by_hpo(hpo_ids, top_n=top_n)
279
+
280
+ # ------------------------------------------------------------------
281
+ # Stats
282
+ # ------------------------------------------------------------------
283
+
284
+ def disease_count(self) -> int:
285
+ return sum(1 for _, d in self.graph.nodes(data=True) if d.get("type") == "Disease")
286
+
287
+ def synonym_count(self) -> int:
288
+ return sum(1 for _, d in self.graph.nodes(data=True) if d.get("type") == "Synonym")
289
+
290
+ def hpo_term_count(self) -> int:
291
+ return sum(1 for _, d in self.graph.nodes(data=True) if d.get("type") == "HPOTerm")
292
+
293
+ def manifestation_count(self) -> int:
294
+ return sum(
295
+ 1 for _, _, d in self.graph.edges(data=True)
296
+ if d.get("label") == "MANIFESTS_AS"
297
+ )
298
+
299
+ def edge_count(self) -> int:
300
+ return self.graph.number_of_edges()
backend/scripts/hello_world.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ hello_world.py
3
+ --------------
4
+ Week 1 Milestone: Query graph store and ChromaDB simultaneously.
5
+
6
+ Primary: Neo4j (Docker) + ChromaDB HTTP (Docker)
7
+ Fallback: LocalGraphStore (JSON) + ChromaDB PersistentClient (embedded)
8
+
9
+ Demonstrates the RareDx core pattern:
10
+ 1. Retrieve disease structured data from graph store (Neo4j / JSON)
11
+ 2. Find semantically related diseases from ChromaDB (BioLORD-2023 vectors)
12
+ 3. Merge and display results
13
+
14
+ Usage:
15
+ python hello_world.py [disease_name]
16
+ python hello_world.py "Marfan syndrome"
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import io
22
+ import time
23
+
24
+ # Force UTF-8 output on Windows terminals
25
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
26
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
27
+ import concurrent.futures
28
+ from pathlib import Path
29
+
30
+ import chromadb
31
+ from chromadb.config import Settings
32
+ from sentence_transformers import SentenceTransformer
33
+ from dotenv import load_dotenv
34
+
35
+ load_dotenv(Path(__file__).parents[2] / ".env")
36
+
37
+ NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
38
+ NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
39
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "raredx_password")
40
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
41
+ CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
42
+ COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
43
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
44
+ CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"
45
+
46
+ QUERY_DISEASE = sys.argv[1] if len(sys.argv) > 1 else "Marfan syndrome"
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Graph store queries (Neo4j primary, LocalGraphStore fallback)
51
+ # ---------------------------------------------------------------------------
52
+
53
+ def fetch_from_graph(disease_name: str) -> tuple[dict | None, str]:
54
+ """Returns (disease_dict or None, backend_label)."""
55
+
56
+ # Try Neo4j first
57
+ try:
58
+ from neo4j import GraphDatabase
59
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
60
+ driver.verify_connectivity()
61
+ with driver.session() as session:
62
+ result = session.run(
63
+ """
64
+ MATCH (d:Disease)
65
+ WHERE toLower(d.name) CONTAINS toLower($name)
66
+ OPTIONAL MATCH (d)-[:HAS_SYNONYM]->(s:Synonym)
67
+ RETURN
68
+ d.orpha_code AS orpha_code,
69
+ d.name AS name,
70
+ d.definition AS definition,
71
+ d.expert_link AS expert_link,
72
+ collect(s.text) AS synonyms
73
+ LIMIT 1
74
+ """,
75
+ name=disease_name,
76
+ )
77
+ record = result.single()
78
+ driver.close()
79
+ if record:
80
+ return dict(record), "Neo4j (Docker)"
81
+ return None, "Neo4j (Docker)"
82
+
83
+ except Exception:
84
+ pass # fall through to local store
85
+
86
+ # LocalGraphStore fallback
87
+ try:
88
+ from graph_store import LocalGraphStore
89
+ store = LocalGraphStore()
90
+ disease = store.find_disease_by_name(disease_name)
91
+ return disease, "LocalGraphStore (JSON)"
92
+ except Exception as exc:
93
+ print(f" Graph store error: {exc}")
94
+ return None, "unavailable"
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # ChromaDB semantic search (HTTP primary, embedded fallback)
99
+ # ---------------------------------------------------------------------------
100
+
101
+ def get_chroma_client() -> chromadb.ClientAPI:
102
+ try:
103
+ client = chromadb.HttpClient(
104
+ host=CHROMA_HOST,
105
+ port=CHROMA_PORT,
106
+ settings=Settings(anonymized_telemetry=False),
107
+ )
108
+ client.heartbeat()
109
+ return client
110
+ except Exception:
111
+ return chromadb.PersistentClient(
112
+ path=str(CHROMA_PERSIST_DIR),
113
+ settings=Settings(anonymized_telemetry=False),
114
+ )
115
+
116
+
117
+ def fetch_from_chromadb(
118
+ query_text: str,
119
+ model: SentenceTransformer,
120
+ n_results: int = 5,
121
+ ) -> tuple[list[dict], str]:
122
+ client = get_chroma_client()
123
+ backend = "ChromaDB HTTP" if hasattr(client, "_api") else "ChromaDB Embedded"
124
+
125
+ collection = client.get_collection(COLLECTION_NAME)
126
+ embedding = model.encode([query_text], normalize_embeddings=True)
127
+ results = collection.query(
128
+ query_embeddings=embedding.tolist(),
129
+ n_results=n_results,
130
+ include=["documents", "metadatas", "distances"],
131
+ )
132
+
133
+ hits = []
134
+ for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
135
+ hits.append({
136
+ "orpha_code": meta.get("orpha_code"),
137
+ "name": meta.get("name"),
138
+ "definition": meta.get("definition", ""),
139
+ "synonyms": meta.get("synonyms", ""),
140
+ "cosine_similarity": round(1 - dist, 4),
141
+ })
142
+ return hits, backend
143
+
144
+
145
+ # ---------------------------------------------------------------------------
146
+ # Display
147
+ # ---------------------------------------------------------------------------
148
+
149
+ BOLD = "\033[1m"
150
+ CYAN = "\033[96m"
151
+ GREEN = "\033[92m"
152
+ YELLOW= "\033[93m"
153
+ DIM = "\033[2m"
154
+ RESET = "\033[0m"
155
+ LINE = "-" * 62
156
+
157
+
158
+ def _wrap(text: str, width: int = 72, indent: str = " ") -> str:
159
+ words = text.split()
160
+ lines, cur = [], []
161
+ for w in words:
162
+ cur.append(w)
163
+ if len(" ".join(cur)) > width:
164
+ lines.append(indent + " ".join(cur[:-1]))
165
+ cur = [w]
166
+ if cur:
167
+ lines.append(indent + " ".join(cur))
168
+ return "\n".join(lines)
169
+
170
+
171
+ def print_graph_result(disease: dict | None, backend: str) -> None:
172
+ print(f"\n{BOLD}{CYAN}[ Graph Store — {backend} ]{RESET}")
173
+ print(LINE)
174
+ if disease is None:
175
+ print(f" {YELLOW}No match found.{RESET}")
176
+ return
177
+ print(f" {BOLD}OrphaCode :{RESET} ORPHA:{disease['orpha_code']}")
178
+ print(f" {BOLD}Name :{RESET} {disease['name']}")
179
+ if disease.get("synonyms"):
180
+ print(f" {BOLD}Synonyms :{RESET} {', '.join(disease['synonyms'])}")
181
+ if disease.get("definition"):
182
+ print(f" {BOLD}Definition :{RESET}")
183
+ print(_wrap(disease["definition"]))
184
+ if disease.get("expert_link"):
185
+ print(f" {BOLD}OrphaNet :{RESET} {DIM}{disease['expert_link']}{RESET}")
186
+
187
+
188
+ def print_chroma_results(hits: list[dict], backend: str) -> None:
189
+ print(f"\n{BOLD}{GREEN}[ ChromaDB — BioLORD-2023 Semantic Neighbours | {backend} ]{RESET}")
190
+ print(LINE)
191
+ if not hits:
192
+ print(f" {YELLOW}No results.{RESET}")
193
+ return
194
+ for rank, hit in enumerate(hits, 1):
195
+ sim = hit["cosine_similarity"]
196
+ bar_len = int(sim * 20)
197
+ bar = "█" * bar_len + "░" * (20 - bar_len)
198
+ print(f" {rank}. [{bar}] {sim:.4f} ORPHA:{hit['orpha_code']} {hit['name']}")
199
+ if hit.get("synonyms"):
200
+ print(f" {DIM}Also: {hit['synonyms']}{RESET}")
201
+
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # Main
205
+ # ---------------------------------------------------------------------------
206
+
207
+ def main() -> None:
208
+ print("=" * 62)
209
+ print("RareDx — Week 1 Hello World Milestone")
210
+ print("=" * 62)
211
+ print(f"\nQuery: {BOLD}{QUERY_DISEASE}{RESET}\n")
212
+
213
+ # Load BioLORD (needed before spawning threads so it is not loaded twice)
214
+ print(f"Loading BioLORD-2023...")
215
+ t0 = time.time()
216
+ model = SentenceTransformer(EMBED_MODEL)
217
+ print(f" Model ready in {time.time() - t0:.1f}s")
218
+
219
+ # Parallel queries
220
+ print(f"\nQuerying graph store and ChromaDB simultaneously...")
221
+ t_start = time.time()
222
+
223
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
224
+ graph_fut = pool.submit(fetch_from_graph, QUERY_DISEASE)
225
+ chroma_fut = pool.submit(fetch_from_chromadb, QUERY_DISEASE, model, 5)
226
+
227
+ disease, graph_backend = graph_fut.result()
228
+ hits, chroma_backend = chroma_fut.result()
229
+
230
+ elapsed = time.time() - t_start
231
+ print(f" Both queries completed in {elapsed:.2f}s")
232
+
233
+ # Display
234
+ print_graph_result(disease, graph_backend)
235
+ print_chroma_results(hits, chroma_backend)
236
+
237
+ # Summary
238
+ graph_ok = disease is not None
239
+ chroma_ok = len(hits) > 0
240
+
241
+ print(f"\n{LINE}")
242
+ print(f"{BOLD}Week 1 Milestone Summary{RESET}")
243
+ print(LINE)
244
+ print(f" Graph store : {'OK' if graph_ok else 'MISS'} — {graph_backend}")
245
+ print(f" ChromaDB : {'OK' if chroma_ok else 'MISS'} — {chroma_backend}")
246
+ print()
247
+
248
+ if graph_ok and chroma_ok:
249
+ print(f" {BOLD}{GREEN}PASSED{RESET} — Neo4j + ChromaDB both responding.")
250
+ else:
251
+ print(f" {YELLOW}PARTIAL{RESET} — one or more backends had no results.")
252
+ sys.exit(1)
253
+ print()
254
+
255
+
256
+ if __name__ == "__main__":
257
+ main()
backend/scripts/ingest_hpo.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ingest_hpo.py
3
+ -------------
4
+ Parses en_product4.xml and adds HPO phenotype associations to the graph.
5
+
6
+ Each disease gets MANIFESTS_AS edges to HPOTerm nodes:
7
+ (:Disease)-[:MANIFESTS_AS {frequency_label, frequency_order, diagnostic_criteria}]->(:HPOTerm)
8
+
9
+ Frequency ordering (for ranking):
10
+ 1 = Very frequent (99-80%)
11
+ 2 = Frequent (79-30%)
12
+ 3 = Occasional (29-5%)
13
+ 4 = Rare (4-1%)
14
+ 5 = Excluded (always absent — negative phenotype)
15
+ 0 = Unknown / not specified
16
+ """
17
+
18
+ import os
19
+ import sys
20
+ from pathlib import Path
21
+ from lxml import etree
22
+ from tqdm import tqdm
23
+ from dotenv import load_dotenv
24
+
25
+ load_dotenv(Path(__file__).parents[2] / ".env")
26
+
27
+ NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
28
+ NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
29
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "raredx_password")
30
+
31
+ HPO_XML = Path(os.getenv("ORPHANET_DATA_DIR", "./data/orphanet")) / "en_product4.xml"
32
+
33
+ FREQUENCY_ORDER = {
34
+ "obligate (100%)": 1,
35
+ "very frequent (99-80%)": 1,
36
+ "frequent (79-30%)": 2,
37
+ "occasional (29-5%)": 3,
38
+ "rare (4-1%)": 4,
39
+ "very rare (<4-1%)": 4,
40
+ "excluded (0%)": 5,
41
+ }
42
+
43
+
44
+ def _text(el, xpath: str) -> str:
45
+ nodes = el.xpath(xpath)
46
+ if not nodes:
47
+ return ""
48
+ val = nodes[0]
49
+ return (val.text or "").strip() if hasattr(val, "text") else str(val).strip()
50
+
51
+
52
+ def parse_hpo_associations(xml_path: Path) -> list[dict]:
53
+ print(f"Parsing {xml_path} ...")
54
+ tree = etree.parse(str(xml_path))
55
+ root = tree.getroot()
56
+
57
+ associations = []
58
+ disorders = root.xpath("//Disorder")
59
+ print(f" Found {len(disorders)} disorders in HPO file.")
60
+
61
+ for disorder in tqdm(disorders, desc=" Parsing disorders", unit="disorder"):
62
+ orpha_code_str = _text(disorder, "OrphaCode")
63
+ if not orpha_code_str:
64
+ continue
65
+ orpha_code = int(orpha_code_str)
66
+
67
+ for assoc in disorder.xpath(".//HPODisorderAssociation"):
68
+ hpo_id = _text(assoc, "HPO/HPOId")
69
+ hpo_term = _text(assoc, "HPO/HPOTerm")
70
+ freq_raw = _text(assoc, "HPOFrequency/Name[@lang='en']").lower()
71
+ diag_crit = _text(assoc, "DiagnosticCriteria/Name[@lang='en']")
72
+
73
+ if not hpo_id or not hpo_term:
74
+ continue
75
+
76
+ freq_order = FREQUENCY_ORDER.get(freq_raw, 0)
77
+
78
+ associations.append({
79
+ "orpha_code": orpha_code,
80
+ "hpo_id": hpo_id,
81
+ "term": hpo_term,
82
+ "frequency_label": freq_raw.capitalize() if freq_raw else "Unknown",
83
+ "frequency_order": freq_order,
84
+ "diagnostic_criteria": diag_crit,
85
+ })
86
+
87
+ print(f" Parsed {len(associations):,} HPO associations.")
88
+ return associations
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # Neo4j path
93
+ # ---------------------------------------------------------------------------
94
+
95
+ NEO4J_HPO_QUERY = """
96
+ UNWIND $rows AS row
97
+ MERGE (h:HPOTerm {hpo_id: row.hpo_id})
98
+ SET h.term = row.term
99
+ WITH h, row
100
+ MATCH (d:Disease {orpha_code: row.orpha_code})
101
+ MERGE (d)-[r:MANIFESTS_AS {hpo_id: row.hpo_id}]->(h)
102
+ SET r.frequency_label = row.frequency_label,
103
+ r.frequency_order = row.frequency_order,
104
+ r.diagnostic_criteria = row.diagnostic_criteria
105
+ """
106
+
107
+
108
+ def try_neo4j(associations: list[dict]) -> bool:
109
+ try:
110
+ from neo4j import GraphDatabase
111
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
112
+ driver.verify_connectivity()
113
+ except Exception as exc:
114
+ print(f" Neo4j not reachable ({exc}). Using local graph store.")
115
+ return False
116
+
117
+ print(" Neo4j connected.")
118
+ BATCH = 500
119
+ try:
120
+ with driver.session() as session:
121
+ # Index on HPO ID
122
+ session.run(
123
+ "CREATE INDEX hpo_id IF NOT EXISTS FOR (h:HPOTerm) ON (h.hpo_id)"
124
+ )
125
+ total = 0
126
+ for i in range(0, len(associations), BATCH):
127
+ session.run(NEO4J_HPO_QUERY, rows=associations[i : i + BATCH])
128
+ total += len(associations[i : i + BATCH])
129
+ print(f" Ingested {min(total, len(associations)):,} / {len(associations):,} ...", end="\r")
130
+ print()
131
+
132
+ hpo_count = session.run("MATCH (h:HPOTerm) RETURN count(h) AS c").single()["c"]
133
+ rel_count = session.run("MATCH ()-[r:MANIFESTS_AS]->() RETURN count(r) AS c").single()["c"]
134
+ print(f" Neo4j: {hpo_count:,} HPO terms, {rel_count:,} MANIFESTS_AS edges.")
135
+ return True
136
+ finally:
137
+ driver.close()
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Local fallback
142
+ # ---------------------------------------------------------------------------
143
+
144
+ def ingest_local(associations: list[dict]) -> None:
145
+ from graph_store import LocalGraphStore
146
+ store = LocalGraphStore()
147
+
148
+ print(f" Loading existing graph ({store.disease_count():,} diseases)...")
149
+ print(f" Adding {len(associations):,} HPO associations...")
150
+
151
+ BATCH = 2000
152
+ for i in tqdm(range(0, len(associations), BATCH), desc=" Ingesting batches"):
153
+ batch = associations[i : i + BATCH]
154
+ for a in batch:
155
+ store.upsert_hpo_term(a["hpo_id"], a["term"])
156
+ store.add_manifestation(
157
+ orpha_code=a["orpha_code"],
158
+ hpo_id=a["hpo_id"],
159
+ frequency_label=a["frequency_label"],
160
+ frequency_order=a["frequency_order"],
161
+ diagnostic_criteria=a["diagnostic_criteria"],
162
+ )
163
+
164
+ print(" Saving graph (this may take a moment for ~26K nodes)...")
165
+ store.save()
166
+ print(f" Graph: {store.disease_count():,} diseases | "
167
+ f"{store.hpo_term_count():,} HPO terms | "
168
+ f"{store.manifestation_count():,} MANIFESTS_AS edges")
169
+ print(f" Saved to {store.path}")
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # Main
174
+ # ---------------------------------------------------------------------------
175
+
176
+ def main() -> None:
177
+ print("=" * 60)
178
+ print("RareDx — Week 2A: Ingest HPO Phenotype Associations")
179
+ print("=" * 60)
180
+
181
+ if not HPO_XML.exists():
182
+ print(f"ERROR: {HPO_XML} not found. Run download_hpo.py first.")
183
+ sys.exit(1)
184
+
185
+ associations = parse_hpo_associations(HPO_XML)
186
+
187
+ print("\nAttempting Neo4j connection...")
188
+ if try_neo4j(associations):
189
+ backend = "Neo4j (Docker)"
190
+ else:
191
+ ingest_local(associations)
192
+ backend = "LocalGraphStore"
193
+
194
+ print(f"\nStep done — backend: {backend}")
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
backend/scripts/ingest_neo4j.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ingest_neo4j.py
3
+ ---------------
4
+ Parses the Orphanet en_product1.xml and loads diseases into a graph store.
5
+
6
+ Primary: Neo4j (Docker service at bolt://localhost:7687)
7
+ Fallback: LocalGraphStore (NetworkX + JSON, no server required)
8
+
9
+ Graph Schema (Week 1):
10
+ (:Disease {orpha_code, name, definition, expert_link})
11
+ (:Synonym {text})
12
+ (:Disease)-[:HAS_SYNONYM]->(:Synonym)
13
+ """
14
+
15
+ import os
16
+ import sys
17
+ import time
18
+ from pathlib import Path
19
+ from lxml import etree
20
+ from dotenv import load_dotenv
21
+
22
+ load_dotenv(Path(__file__).parents[2] / ".env")
23
+
24
+ NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
25
+ NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
26
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "raredx_password")
27
+ XML_PATH = Path(os.getenv("ORPHANET_XML", "./data/orphanet/en_product1.xml"))
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # XML parsing
32
+ # ---------------------------------------------------------------------------
33
+
34
+ def _text(element, xpath: str) -> str:
35
+ nodes = element.xpath(xpath)
36
+ if nodes:
37
+ val = nodes[0]
38
+ return (val.text or "").strip() if hasattr(val, "text") else str(val).strip()
39
+ return ""
40
+
41
+
42
+ def parse_disorders(xml_path: Path) -> list[dict]:
43
+ print(f"Parsing {xml_path} ...")
44
+ tree = etree.parse(str(xml_path))
45
+ root = tree.getroot()
46
+
47
+ disorders = []
48
+ for disorder in root.xpath("//Disorder"):
49
+ orpha_code = _text(disorder, "OrphaCode")
50
+ name = _text(disorder, "Name[@lang='en']")
51
+ definition = _text(disorder, "TextAuto[@lang='en']")
52
+ expert_link = _text(disorder, "ExpertLink[@lang='en']")
53
+ synonyms = [
54
+ s.text.strip()
55
+ for s in disorder.xpath("SynonymList/Synonym[@lang='en']")
56
+ if s.text and s.text.strip()
57
+ ]
58
+ if not orpha_code or not name:
59
+ continue
60
+ disorders.append(
61
+ {
62
+ "orpha_code": int(orpha_code),
63
+ "name": name,
64
+ "definition": definition,
65
+ "expert_link": expert_link,
66
+ "synonyms": synonyms,
67
+ }
68
+ )
69
+
70
+ print(f" Parsed {len(disorders)} disorders.")
71
+ return disorders
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Neo4j path (Docker)
76
+ # ---------------------------------------------------------------------------
77
+
78
+ def try_neo4j(disorders: list[dict]) -> bool:
79
+ """
80
+ Attempt to connect to Neo4j and ingest data.
81
+ Returns True on success, False if Neo4j is unavailable.
82
+ """
83
+ try:
84
+ from neo4j import GraphDatabase
85
+ except ImportError:
86
+ return False
87
+
88
+ try:
89
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
90
+ # Fast connectivity check (1-second timeout)
91
+ driver.verify_connectivity()
92
+ except Exception as exc:
93
+ print(f" Neo4j not reachable ({exc}). Falling back to local graph store.")
94
+ return False
95
+
96
+ print(" Neo4j connected.")
97
+ try:
98
+ with driver.session() as session:
99
+ _neo4j_setup_schema(session)
100
+ ingested = _neo4j_upsert(session, disorders)
101
+ _neo4j_verify(session)
102
+ print(f" Ingested {ingested} diseases into Neo4j.")
103
+ return True
104
+ finally:
105
+ driver.close()
106
+
107
+
108
+ def _neo4j_setup_schema(session) -> None:
109
+ session.run(
110
+ "CREATE CONSTRAINT disease_orpha_code IF NOT EXISTS "
111
+ "FOR (d:Disease) REQUIRE d.orpha_code IS UNIQUE"
112
+ )
113
+ session.run(
114
+ "CREATE INDEX disease_name IF NOT EXISTS FOR (d:Disease) ON (d.name)"
115
+ )
116
+ session.run(
117
+ "CREATE CONSTRAINT synonym_text IF NOT EXISTS "
118
+ "FOR (s:Synonym) REQUIRE s.text IS UNIQUE"
119
+ )
120
+ print(" Schema constraints created.")
121
+
122
+
123
+ def _neo4j_upsert(session, disorders: list[dict]) -> int:
124
+ query = """
125
+ UNWIND $rows AS row
126
+ MERGE (d:Disease {orpha_code: row.orpha_code})
127
+ SET d.name = row.name,
128
+ d.definition = row.definition,
129
+ d.expert_link = row.expert_link
130
+ WITH d, row
131
+ UNWIND row.synonyms AS syn_text
132
+ MERGE (s:Synonym {text: syn_text})
133
+ MERGE (d)-[:HAS_SYNONYM]->(s)
134
+ """
135
+ BATCH = 200
136
+ total = 0
137
+ for i in range(0, len(disorders), BATCH):
138
+ session.run(query, rows=disorders[i : i + BATCH])
139
+ total += len(disorders[i : i + BATCH])
140
+ print(f" Ingested {min(total, len(disorders))} / {len(disorders)} ...", end="\r")
141
+ print()
142
+ return total
143
+
144
+
145
+ def _neo4j_verify(session) -> None:
146
+ dc = session.run("MATCH (d:Disease) RETURN count(d) AS c").single()["c"]
147
+ sc = session.run("MATCH (s:Synonym) RETURN count(s) AS c").single()["c"]
148
+ rc = session.run("MATCH ()-[r:HAS_SYNONYM]->() RETURN count(r) AS c").single()["c"]
149
+ print(f" Neo4j counts: {dc} diseases, {sc} synonyms, {rc} edges.")
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Local fallback path (NetworkX + JSON)
154
+ # ---------------------------------------------------------------------------
155
+
156
+ def ingest_local(disorders: list[dict]) -> None:
157
+ from graph_store import LocalGraphStore
158
+ store = LocalGraphStore()
159
+ ingested = store.upsert_disorders_bulk(disorders)
160
+ print(f" LocalGraphStore: {store.disease_count()} diseases, "
161
+ f"{store.synonym_count()} synonyms, {store.edge_count()} edges.")
162
+ print(f" Saved to {store.path}")
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # Main
167
+ # ---------------------------------------------------------------------------
168
+
169
+ def main() -> None:
170
+ print("=" * 60)
171
+ print("RareDx — Step 2: Ingest Orphanet Data into Graph Store")
172
+ print("=" * 60)
173
+
174
+ if not XML_PATH.exists():
175
+ print(f"ERROR: XML not found at {XML_PATH}. Run download_orphanet.py first.")
176
+ sys.exit(1)
177
+
178
+ disorders = parse_disorders(XML_PATH)
179
+
180
+ print("\nAttempting Neo4j connection...")
181
+ if try_neo4j(disorders):
182
+ backend = "Neo4j (Docker)"
183
+ else:
184
+ print("Using local graph store (NetworkX + JSON)...")
185
+ ingest_local(disorders)
186
+ backend = "LocalGraphStore"
187
+
188
+ print(f"\nStep 2 complete — backend: {backend}")
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
backend/scripts/milestone_2a.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ milestone_2a.py
3
+ ---------------
4
+ Week 2A Milestone: Symptom-to-candidate-disease via graph phenotype matching.
5
+
6
+ Given a list of clinical symptoms, this script:
7
+ 1. Maps symptoms to HPO term IDs via the graph (HPOTerm name search)
8
+ 2. Runs the MANIFESTS_AS graph traversal to find matching diseases
9
+ 3. Runs BioLORD-2023 semantic search in ChromaDB in parallel
10
+ 4. Merges both rankings into a unified differential diagnosis list
11
+
12
+ This is the first real diagnostic query in RareDx.
13
+
14
+ Usage:
15
+ python milestone_2a.py
16
+ python milestone_2a.py "arachnodactyly" "ectopia lentis" "aortic dilation"
17
+ """
18
+
19
+ import io
20
+ import os
21
+ import sys
22
+ import time
23
+ import concurrent.futures
24
+ from pathlib import Path
25
+
26
+ import chromadb
27
+ from chromadb.config import Settings
28
+ from sentence_transformers import SentenceTransformer
29
+ from dotenv import load_dotenv
30
+
31
+ # UTF-8 output for Windows
32
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
33
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
34
+
35
+ load_dotenv(Path(__file__).parents[2] / ".env")
36
+
37
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
38
+ CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
39
+ COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
40
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
41
+ CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"
42
+
43
+ # Default test case: classic Marfan syndrome presentation
44
+ DEFAULT_SYMPTOMS = [
45
+ "arachnodactyly",
46
+ "ectopia lentis",
47
+ "aortic root dilatation",
48
+ "scoliosis",
49
+ "tall stature",
50
+ ]
51
+
52
+ symptoms = sys.argv[1:] if len(sys.argv) > 1 else DEFAULT_SYMPTOMS
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Graph query
57
+ # ---------------------------------------------------------------------------
58
+
59
+ def graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
60
+ """
61
+ Returns (ranked_diseases, resolved_hpo_ids, backend_label).
62
+ Tries Neo4j first, then LocalGraphStore.
63
+ """
64
+ # Try Neo4j
65
+ try:
66
+ from neo4j import GraphDatabase
67
+ from neo4j import GraphDatabase as gdb
68
+ neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
69
+ neo4j_user = os.getenv("NEO4J_USER", "neo4j")
70
+ neo4j_pass = os.getenv("NEO4J_PASSWORD", "raredx_password")
71
+ driver = gdb.driver(neo4j_uri, auth=(neo4j_user, neo4j_pass))
72
+ driver.verify_connectivity()
73
+
74
+ with driver.session() as session:
75
+ # Resolve symptoms to HPO IDs
76
+ hpo_ids = []
77
+ for sym in symptom_list:
78
+ r = session.run(
79
+ "MATCH (h:HPOTerm) WHERE toLower(h.term) CONTAINS toLower($s) "
80
+ "RETURN h.hpo_id AS hid LIMIT 1",
81
+ s=sym,
82
+ )
83
+ rec = r.single()
84
+ if rec:
85
+ hpo_ids.append(rec["hid"])
86
+
87
+ if not hpo_ids:
88
+ driver.close()
89
+ return [], [], "Neo4j (Docker)"
90
+
91
+ # Graph traversal
92
+ result = session.run(
93
+ """
94
+ UNWIND $hpo_ids AS hid
95
+ MATCH (d:Disease)-[r:MANIFESTS_AS]->(h:HPOTerm {hpo_id: hid})
96
+ WHERE r.frequency_order <> 5
97
+ WITH d, count(h) AS match_count,
98
+ sum(CASE r.frequency_order
99
+ WHEN 1 THEN 5 WHEN 2 THEN 4
100
+ WHEN 3 THEN 3 WHEN 4 THEN 2
101
+ ELSE 1 END) AS freq_score,
102
+ collect({hpo_id: h.hpo_id, term: h.term,
103
+ freq: r.frequency_label}) AS matched_hpo
104
+ WHERE match_count >= 1
105
+ RETURN d.orpha_code AS orpha_code, d.name AS name,
106
+ d.definition AS definition,
107
+ match_count, freq_score, matched_hpo
108
+ ORDER BY match_count DESC, freq_score DESC
109
+ LIMIT 10
110
+ """,
111
+ hpo_ids=hpo_ids,
112
+ )
113
+ diseases = [dict(r) for r in result]
114
+
115
+ driver.close()
116
+ return diseases, hpo_ids, "Neo4j (Docker)"
117
+
118
+ except Exception:
119
+ pass
120
+
121
+ # LocalGraphStore fallback
122
+ from graph_store import LocalGraphStore
123
+ store = LocalGraphStore()
124
+
125
+ # Resolve symptom strings to HPO IDs
126
+ hpo_ids = []
127
+ for sym in symptom_list:
128
+ sym_lower = sym.lower()
129
+ for nid, attrs in store.graph.nodes(data=True):
130
+ if attrs.get("type") == "HPOTerm":
131
+ if sym_lower in attrs.get("term", "").lower():
132
+ hpo_ids.append(attrs["hpo_id"])
133
+ break
134
+
135
+ diseases = store.find_diseases_by_hpo(hpo_ids, top_n=10)
136
+ return diseases, hpo_ids, "LocalGraphStore (JSON)"
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # ChromaDB semantic search
141
+ # ---------------------------------------------------------------------------
142
+
143
+ def chroma_search(
144
+ symptom_list: list[str],
145
+ model: SentenceTransformer,
146
+ n: int = 10,
147
+ ) -> tuple[list[dict], str]:
148
+ """Embed symptom list as a clinical query and search ChromaDB."""
149
+ query = "Patient presents with: " + ", ".join(symptom_list) + "."
150
+
151
+ try:
152
+ client = chromadb.HttpClient(
153
+ host=CHROMA_HOST,
154
+ port=CHROMA_PORT,
155
+ settings=Settings(anonymized_telemetry=False),
156
+ )
157
+ client.heartbeat()
158
+ backend = "ChromaDB HTTP"
159
+ except Exception:
160
+ client = chromadb.PersistentClient(
161
+ path=str(CHROMA_PERSIST_DIR),
162
+ settings=Settings(anonymized_telemetry=False),
163
+ )
164
+ backend = "ChromaDB Embedded"
165
+
166
+ collection = client.get_collection(COLLECTION_NAME)
167
+ embedding = model.encode([query], normalize_embeddings=True)
168
+ results = collection.query(
169
+ query_embeddings=embedding.tolist(),
170
+ n_results=n,
171
+ include=["metadatas", "distances"],
172
+ )
173
+
174
+ hits = []
175
+ for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
176
+ hits.append({
177
+ "orpha_code": meta.get("orpha_code"),
178
+ "name": meta.get("name"),
179
+ "definition": meta.get("definition", ""),
180
+ "cosine_similarity": round(1 - dist, 4),
181
+ })
182
+ return hits, backend
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Score fusion
187
+ # ---------------------------------------------------------------------------
188
+
189
+ def fuse_rankings(
190
+ graph_results: list[dict],
191
+ chroma_results: list[dict],
192
+ ) -> list[dict]:
193
+ """
194
+ Reciprocal Rank Fusion (RRF) of graph and semantic rankings.
195
+ RRF score = sum(1 / (k + rank)) for each list the disease appears in.
196
+ k=60 is the standard constant.
197
+ """
198
+ K = 60
199
+ scores: dict[str, dict] = {}
200
+
201
+ for rank, d in enumerate(graph_results, 1):
202
+ key = str(d["orpha_code"])
203
+ if key not in scores:
204
+ scores[key] = {"orpha_code": d["orpha_code"], "name": d["name"],
205
+ "definition": d.get("definition", ""),
206
+ "graph_rank": None, "chroma_rank": None,
207
+ "graph_matches": None, "chroma_sim": None,
208
+ "rrf_score": 0.0}
209
+ scores[key]["rrf_score"] += 1 / (K + rank)
210
+ scores[key]["graph_rank"] = rank
211
+ scores[key]["graph_matches"] = d.get("match_count", 0)
212
+
213
+ for rank, d in enumerate(chroma_results, 1):
214
+ key = str(d["orpha_code"])
215
+ if key not in scores:
216
+ scores[key] = {"orpha_code": d["orpha_code"], "name": d["name"],
217
+ "definition": d.get("definition", ""),
218
+ "graph_rank": None, "chroma_rank": None,
219
+ "graph_matches": None, "chroma_sim": None,
220
+ "rrf_score": 0.0}
221
+ scores[key]["rrf_score"] += 1 / (K + rank)
222
+ scores[key]["chroma_rank"] = rank
223
+ scores[key]["chroma_sim"] = d.get("cosine_similarity")
224
+
225
+ return sorted(scores.values(), key=lambda x: x["rrf_score"], reverse=True)
226
+
227
+
228
+ # ---------------------------------------------------------------------------
229
+ # Display
230
+ # ---------------------------------------------------------------------------
231
+
232
+ BOLD = "\033[1m"
233
+ CYAN = "\033[96m"
234
+ GREEN = "\033[92m"
235
+ YELLOW = "\033[93m"
236
+ MAGENTA= "\033[95m"
237
+ DIM = "\033[2m"
238
+ RESET = "\033[0m"
239
+ LINE = "-" * 66
240
+
241
+
242
+ def print_section(title: str, color: str) -> None:
243
+ print(f"\n{BOLD}{color}{title}{RESET}")
244
+ print(LINE)
245
+
246
+
247
+ def print_graph_hits(diseases: list[dict], hpo_ids: list[str], backend: str) -> None:
248
+ print_section(f"[ Graph Traversal — {backend} ]", CYAN)
249
+ if not diseases:
250
+ print(f" {YELLOW}No graph matches. HPO IDs resolved: {hpo_ids}{RESET}")
251
+ return
252
+ print(f" {DIM}HPO IDs resolved: {', '.join(hpo_ids)}{RESET}\n")
253
+ for rank, d in enumerate(diseases[:5], 1):
254
+ mc = d.get("match_count", d.get("match_count", "?"))
255
+ total = d.get("total_query_terms", len(symptoms))
256
+ print(f" {rank}. ORPHA:{d['orpha_code']} {BOLD}{d['name']}{RESET}")
257
+ print(f" Phenotype matches: {mc}/{total}")
258
+ matched = d.get("matched_hpo", [])
259
+ if matched:
260
+ terms = ", ".join(m["term"] for m in matched[:4])
261
+ print(f" {DIM}Matched: {terms}{RESET}")
262
+
263
+
264
+ def print_chroma_hits(hits: list[dict], backend: str) -> None:
265
+ print_section(f"[ Semantic Search — BioLORD-2023 | {backend} ]", GREEN)
266
+ for rank, h in enumerate(hits[:5], 1):
267
+ sim = h["cosine_similarity"]
268
+ bar = "█" * int(sim * 20) + "░" * (20 - int(sim * 20))
269
+ print(f" {rank}. [{bar}] {sim:.4f} ORPHA:{h['orpha_code']} {h['name']}")
270
+
271
+
272
+ def print_fused(fused: list[dict]) -> None:
273
+ print_section("[ Fused Differential Diagnosis (RRF) ]", MAGENTA)
274
+ print(f" {'Rank':<5} {'RRF':>6} {'Graph':>5} {'Chroma':>6} Disease")
275
+ print(f" {'-'*4} {'-'*6} {'-'*5} {'-'*6} {'-'*35}")
276
+ for rank, d in enumerate(fused[:10], 1):
277
+ gr = f"#{d['graph_rank']}" if d["graph_rank"] else " - "
278
+ cr = f"#{d['chroma_rank']}" if d["chroma_rank"] else " - "
279
+ rrf = d["rrf_score"]
280
+ print(f" {rank:<5} {rrf:.4f} {gr:>5} {cr:>6} {d['name']}")
281
+
282
+
283
+ # ---------------------------------------------------------------------------
284
+ # Main
285
+ # ---------------------------------------------------------------------------
286
+
287
+ def main() -> None:
288
+ print("=" * 66)
289
+ print("RareDx — Week 2A Milestone: Symptom-to-Diagnosis")
290
+ print("=" * 66)
291
+ print(f"\n{BOLD}Clinical query symptoms:{RESET}")
292
+ for s in symptoms:
293
+ print(f" - {s}")
294
+
295
+ print(f"\nLoading BioLORD-2023 ...")
296
+ t0 = time.time()
297
+ model = SentenceTransformer(EMBED_MODEL)
298
+ print(f" Ready in {time.time()-t0:.1f}s")
299
+
300
+ # Parallel: graph traversal + semantic search
301
+ print("\nRunning graph traversal and semantic search in parallel...")
302
+ t_start = time.time()
303
+ with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
304
+ graph_fut = pool.submit(graph_search, symptoms)
305
+ chroma_fut = pool.submit(chroma_search, symptoms, model, 10)
306
+
307
+ graph_diseases, hpo_ids, graph_backend = graph_fut.result()
308
+ chroma_hits, chroma_backend = chroma_fut.result()
309
+
310
+ elapsed = time.time() - t_start
311
+ print(f" Completed in {elapsed:.2f}s")
312
+
313
+ # Display individual results
314
+ print_graph_hits(graph_diseases, hpo_ids, graph_backend)
315
+ print_chroma_hits(chroma_hits, chroma_backend)
316
+
317
+ # Fuse
318
+ fused = fuse_rankings(graph_diseases, chroma_hits)
319
+ print_fused(fused)
320
+
321
+ # Summary
322
+ graph_ok = len(graph_diseases) > 0
323
+ chroma_ok = len(chroma_hits) > 0
324
+ fused_ok = len(fused) > 0
325
+
326
+ print(f"\n{LINE}")
327
+ print(f"{BOLD}Week 2A Milestone Summary{RESET}")
328
+ print(LINE)
329
+ print(f" Graph traversal : {'OK' if graph_ok else 'MISS'} — {len(graph_diseases)} candidates — {graph_backend}")
330
+ print(f" Semantic search : {'OK' if chroma_ok else 'MISS'} — {len(chroma_hits)} candidates — {chroma_backend}")
331
+ print(f" Fused ranking : {'OK' if fused_ok else 'MISS'} — {len(fused)} unique candidates")
332
+ print()
333
+
334
+ if graph_ok and chroma_ok and fused_ok:
335
+ top = fused[0]
336
+ print(f" {BOLD}{GREEN}PASSED{RESET} — Top diagnosis: {top['name']} (ORPHA:{top['orpha_code']})")
337
+ else:
338
+ print(f" {YELLOW}PARTIAL or FAILED — check individual backends above{RESET}")
339
+ sys.exit(1)
340
+ print()
341
+
342
+
343
+ if __name__ == "__main__":
344
+ main()
backend/scripts/milestone_2b.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ milestone_2b.py
3
+ ---------------
4
+ Week 2B Milestone: Free-text clinical note → differential diagnosis.
5
+
6
+ Tests the full pipeline end-to-end:
7
+ Clinical note
8
+ -> SymptomParser (BioLORD semantic HPO mapping)
9
+ -> Graph traversal (MANIFESTS_AS phenotype matching)
10
+ -> ChromaDB semantic search (HPO-enriched embeddings)
11
+ -> RRF fusion
12
+ -> Ranked differential diagnosis
13
+
14
+ Target note:
15
+ "18 year old male, extremely tall, displaced lens in left eye,
16
+ heart murmur, flexible joints, scoliosis"
17
+
18
+ Expected: Marfan syndrome (ORPHA:558) in top 3.
19
+ """
20
+
21
+ import io
22
+ import sys
23
+ import time
24
+ from pathlib import Path
25
+
26
+ # UTF-8 output for Windows
27
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
28
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
29
+
30
+ # Make sure both scripts/ and api/ are importable
31
+ ROOT = Path(__file__).parents[2]
32
+ sys.path.insert(0, str(ROOT / "backend" / "scripts"))
33
+ sys.path.insert(0, str(ROOT / "backend"))
34
+
35
+ from api.pipeline import DiagnosisPipeline
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Test case
39
+ # ---------------------------------------------------------------------------
40
+
41
+ NOTE = (
42
+ "18 year old male, extremely tall, displaced lens in left eye, "
43
+ "heart murmur, flexible joints, scoliosis"
44
+ )
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Display helpers
48
+ # ---------------------------------------------------------------------------
49
+
50
+ BOLD = "\033[1m"
51
+ CYAN = "\033[96m"
52
+ GREEN = "\033[92m"
53
+ YELLOW = "\033[93m"
54
+ MAGENTA = "\033[95m"
55
+ RED = "\033[91m"
56
+ DIM = "\033[2m"
57
+ RESET = "\033[0m"
58
+ LINE = "-" * 68
59
+
60
+
61
+ def section(title: str, color: str) -> None:
62
+ print(f"\n{BOLD}{color}{title}{RESET}")
63
+ print(LINE)
64
+
65
+
66
+ def print_hpo_matches(matches: list[dict]) -> None:
67
+ section("[ Step 1 — Symptom Parser: Free-text -> HPO Terms ]", CYAN)
68
+ if not matches:
69
+ print(f" {YELLOW}No HPO terms resolved.{RESET}")
70
+ return
71
+ print(f" {'Score':>6} {'HPO ID':<12} {'HPO Term':<38} Phrase")
72
+ print(f" {'-'*6} {'-'*12} {'-'*38} {'-'*28}")
73
+ for m in matches:
74
+ print(f" {m['score']:>6.4f} {m['hpo_id']:<12} {m['term']:<38} \"{m['phrase']}\"")
75
+
76
+
77
+ def print_candidates(candidates: list[dict], n: int = 10) -> None:
78
+ section("[ Step 4 — Fused Differential Diagnosis (RRF) ]", MAGENTA)
79
+ print(f" {'#':<4} {'RRF':>7} {'Graph':>6} {'Vec':>5} {'Match':>5} Disease")
80
+ print(f" {'-'*4} {'-'*7} {'-'*6} {'-'*5} {'-'*5} {'-'*38}")
81
+
82
+ for c in candidates[:n]:
83
+ gr = f"#{c['graph_rank']}" if c.get("graph_rank") else " - "
84
+ cr = f"#{c['chroma_rank']}" if c.get("chroma_rank") else " - "
85
+ mc = str(c.get("graph_matches", "-")) if c.get("graph_matches") is not None else " - "
86
+ name = c["name"][:42]
87
+
88
+ # Highlight Marfan
89
+ highlight = BOLD + GREEN if "Marfan" in c["name"] else ""
90
+ reset_hl = RESET if highlight else ""
91
+
92
+ print(
93
+ f" {c['rank']:<4} {c['rrf_score']:>7.5f} {gr:>6} {cr:>5} {mc:>5} "
94
+ f"{highlight}{name}{reset_hl}"
95
+ )
96
+
97
+ # Show matched phenotypes for top 3
98
+ if c["rank"] <= 3 and c.get("matched_hpo"):
99
+ terms = ", ".join(h["term"] for h in c["matched_hpo"][:5])
100
+ print(f" {DIM}Phenotypes: {terms}{RESET}")
101
+
102
+
103
+ # ---------------------------------------------------------------------------
104
+ # Milestone validation
105
+ # ---------------------------------------------------------------------------
106
+
107
+ def validate(result: dict) -> bool:
108
+ """Pass if Marfan syndrome appears in top 5."""
109
+ candidates = result.get("candidates", [])
110
+ for c in candidates[:5]:
111
+ if "558" in str(c.get("orpha_code", "")) or "Marfan syndrome" == c.get("name", ""):
112
+ return True
113
+ return False
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # Main
118
+ # ---------------------------------------------------------------------------
119
+
120
+ def main() -> None:
121
+ print("=" * 68)
122
+ print("RareDx — Week 2B Milestone: Clinical Note -> Diagnosis")
123
+ print("=" * 68)
124
+ print(f"\n{BOLD}Clinical note:{RESET}")
125
+ print(f" \"{NOTE}\"\n")
126
+
127
+ # Initialise pipeline (loads model + HPO index + graph + ChromaDB)
128
+ t0 = time.time()
129
+ pipeline = DiagnosisPipeline()
130
+ print(f"\nPipeline initialised in {time.time()-t0:.1f}s\n")
131
+
132
+ # Run diagnosis
133
+ print(f"Running diagnosis...")
134
+ result = pipeline.diagnose(NOTE, top_n=15, threshold=0.52)
135
+ print(f" Completed in {result['elapsed_seconds']}s")
136
+
137
+ # Display
138
+ print_hpo_matches(result["hpo_matches"])
139
+
140
+ section("[ Step 2+3 — Graph + Semantic Search Summary ]", CYAN)
141
+ hpo_used = result["hpo_ids_used"]
142
+ print(f" HPO IDs fed to graph: {', '.join(hpo_used) if hpo_used else 'none'}")
143
+ print(f" Graph candidates: {sum(1 for c in result['candidates'] if c.get('graph_rank'))}")
144
+ print(f" ChromaDB candidates: {sum(1 for c in result['candidates'] if c.get('chroma_rank'))}")
145
+ print(f" Overlap (both): {sum(1 for c in result['candidates'] if c.get('graph_rank') and c.get('chroma_rank'))}")
146
+
147
+ print_candidates(result["candidates"])
148
+
149
+ # Summary
150
+ passed = validate(result)
151
+ top = result.get("top_diagnosis", {})
152
+
153
+ print(f"\n{LINE}")
154
+ print(f"{BOLD}Week 2B Milestone Summary{RESET}")
155
+ print(LINE)
156
+ print(f" HPO terms resolved : {len(result['hpo_matches'])} / {len(result['phrases_extracted'])} phrases matched")
157
+ print(f" Total candidates : {len(result['candidates'])} unique diseases")
158
+ print(f" Graph backend : {result['graph_backend']}")
159
+ print(f" ChromaDB backend : {result['chroma_backend']}")
160
+ print(f" Elapsed : {result['elapsed_seconds']}s")
161
+ print()
162
+
163
+ if passed:
164
+ marfan_rank = next(
165
+ (c["rank"] for c in result["candidates"]
166
+ if "Marfan syndrome" == c.get("name") or "558" in str(c.get("orpha_code", ""))),
167
+ "?",
168
+ )
169
+ print(f" {BOLD}{GREEN}PASSED{RESET} — Marfan syndrome (ORPHA:558) at rank #{marfan_rank}")
170
+ else:
171
+ print(f" {RED}FAILED{RESET} — Marfan syndrome not in top 5")
172
+ print(f" Top result: {top.get('name')} (ORPHA:{top.get('orpha_code')})")
173
+ sys.exit(1)
174
+
175
+ print()
176
+ print(f" {BOLD}Top diagnosis:{RESET} {top.get('name')} (ORPHA:{top.get('orpha_code')})")
177
+ if top.get("definition"):
178
+ words = top["definition"].split()
179
+ snippet = " ".join(words[:30]) + ("..." if len(words) > 30 else "")
180
+ print(f" {DIM}{snippet}{RESET}")
181
+ print()
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()
backend/scripts/reembed_chromadb.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ reembed_chromadb.py
3
+ -------------------
4
+ Rebuilds ChromaDB embeddings with HPO-enriched disease descriptions.
5
+
6
+ Week 1 embedding text:
7
+ "{name}. {definition}. Also known as: {synonyms}."
8
+
9
+ Week 2B embedding text (this script):
10
+ "{name}. {definition}. Phenotypes: {hpo_terms ordered by frequency}.
11
+ Also known as: {synonyms}."
12
+
13
+ Adding phenotype terms directly into the embedding space means ChromaDB
14
+ can now find diseases by symptoms, not just by name similarity.
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import chromadb
22
+ from chromadb.config import Settings
23
+ from sentence_transformers import SentenceTransformer
24
+ from tqdm import tqdm
25
+ from dotenv import load_dotenv
26
+
27
+ load_dotenv(Path(__file__).parents[2] / ".env")
28
+
29
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
30
+ CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
31
+ COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
32
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
33
+ CHROMA_PERSIST = Path(__file__).parents[2] / "data" / "chromadb"
34
+
35
+ BATCH_SIZE = 32
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Build enriched document text per disease
40
+ # ---------------------------------------------------------------------------
41
+
42
+ def build_documents(store) -> list[dict]:
43
+ """
44
+ Pull every disease from the graph store and build HPO-enriched embed text.
45
+ HPO terms are sorted by frequency_order (most frequent first).
46
+ """
47
+ docs = []
48
+ disease_nodes = [
49
+ (nid, attrs)
50
+ for nid, attrs in store.graph.nodes(data=True)
51
+ if attrs.get("type") == "Disease"
52
+ ]
53
+
54
+ for nid, attrs in tqdm(disease_nodes, desc=" Building documents", unit="disease"):
55
+ orpha_code = attrs["orpha_code"]
56
+ name = attrs.get("name", "")
57
+ definition = attrs.get("definition", "")
58
+
59
+ # Collect synonyms and HPO terms from graph edges
60
+ synonyms = []
61
+ hpo_terms = []
62
+
63
+ for v, edata in store.graph[nid].items():
64
+ vattrs = store.graph.nodes[v]
65
+ vtype = vattrs.get("type")
66
+
67
+ if vtype == "Synonym":
68
+ synonyms.append(vattrs["text"])
69
+
70
+ elif vtype == "HPOTerm" and edata.get("label") == "MANIFESTS_AS":
71
+ freq_order = edata.get("frequency_order", 9)
72
+ # Skip excluded phenotypes (frequency_order == 5)
73
+ if freq_order == 5:
74
+ continue
75
+ hpo_terms.append((freq_order, vattrs.get("term", "")))
76
+
77
+ # Sort HPO terms: most frequent first
78
+ hpo_terms.sort(key=lambda x: x[0])
79
+ hpo_term_names = [t[1] for t in hpo_terms[:30]] # cap at 30 to control token length
80
+
81
+ # Build enriched text
82
+ parts = [name]
83
+ if definition:
84
+ parts.append(definition)
85
+ if hpo_term_names:
86
+ parts.append("Clinical features: " + ", ".join(hpo_term_names) + ".")
87
+ if synonyms:
88
+ parts.append("Also known as: " + ", ".join(synonyms) + ".")
89
+
90
+ embed_text = " ".join(parts)
91
+
92
+ docs.append({
93
+ "id": f"ORPHA:{orpha_code}",
94
+ "orpha_code": str(orpha_code),
95
+ "name": name,
96
+ "definition": definition,
97
+ "synonyms": ", ".join(synonyms),
98
+ "hpo_terms": ", ".join(hpo_term_names[:15]), # store subset in metadata
99
+ "embed_text": embed_text,
100
+ })
101
+
102
+ return docs
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # ChromaDB helpers
107
+ # ---------------------------------------------------------------------------
108
+
109
+ def get_chroma_client() -> tuple[chromadb.ClientAPI, str]:
110
+ try:
111
+ client = chromadb.HttpClient(
112
+ host=CHROMA_HOST, port=CHROMA_PORT,
113
+ settings=Settings(anonymized_telemetry=False),
114
+ )
115
+ client.heartbeat()
116
+ return client, "ChromaDB HTTP (Docker)"
117
+ except Exception:
118
+ CHROMA_PERSIST.mkdir(parents=True, exist_ok=True)
119
+ client = chromadb.PersistentClient(
120
+ path=str(CHROMA_PERSIST),
121
+ settings=Settings(anonymized_telemetry=False),
122
+ )
123
+ return client, "ChromaDB Embedded"
124
+
125
+
126
+ def recreate_collection(client: chromadb.ClientAPI, name: str) -> chromadb.Collection:
127
+ try:
128
+ client.delete_collection(name)
129
+ print(f" Deleted existing collection '{name}'.")
130
+ except Exception:
131
+ pass
132
+ col = client.create_collection(name=name, metadata={"hnsw:space": "cosine"})
133
+ print(f" Created collection '{name}'.")
134
+ return col
135
+
136
+
137
+ def upsert_batches(col, docs: list[dict], embeddings) -> None:
138
+ for i in range(0, len(docs), BATCH_SIZE):
139
+ bd = docs[i : i + BATCH_SIZE]
140
+ be = embeddings[i : i + BATCH_SIZE]
141
+ col.upsert(
142
+ ids = [d["id"] for d in bd],
143
+ embeddings = be,
144
+ documents = [d["embed_text"] for d in bd],
145
+ metadatas = [{
146
+ "orpha_code": d["orpha_code"],
147
+ "name": d["name"],
148
+ "definition": d["definition"][:500],
149
+ "synonyms": d["synonyms"],
150
+ "hpo_terms": d["hpo_terms"],
151
+ } for d in bd],
152
+ )
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Main
157
+ # ---------------------------------------------------------------------------
158
+
159
+ def main() -> None:
160
+ print("=" * 60)
161
+ print("RareDx — Week 2B Step 1: Re-embed with HPO-Enriched Text")
162
+ print("=" * 60)
163
+
164
+ # Load graph store
165
+ sys.path.insert(0, str(Path(__file__).parent))
166
+ from graph_store import LocalGraphStore
167
+ store = LocalGraphStore()
168
+ print(f"\nGraph: {store.disease_count():,} diseases | "
169
+ f"{store.hpo_term_count():,} HPO terms | "
170
+ f"{store.manifestation_count():,} phenotype edges")
171
+
172
+ # Build documents
173
+ print("\nBuilding HPO-enriched documents...")
174
+ docs = build_documents(store)
175
+ print(f" {len(docs):,} documents ready.")
176
+
177
+ # Sample — show the enrichment difference
178
+ sample = next((d for d in docs if "Marfan" in d["name"]), docs[0])
179
+ print(f"\n Sample — {sample['name']}:")
180
+ preview = sample["embed_text"][:300]
181
+ print(f" {preview}...")
182
+
183
+ # Load model
184
+ print(f"\nLoading {EMBED_MODEL}...")
185
+ model = SentenceTransformer(EMBED_MODEL)
186
+ print(f" Embedding dim: {model.get_sentence_embedding_dimension()}")
187
+
188
+ # Embed
189
+ print(f"\nEmbedding {len(docs):,} documents (batch={BATCH_SIZE})...")
190
+ texts = [d["embed_text"] for d in docs]
191
+ embeddings = model.encode(
192
+ texts,
193
+ batch_size=BATCH_SIZE,
194
+ show_progress_bar=True,
195
+ normalize_embeddings=True,
196
+ )
197
+ print(f" Shape: {embeddings.shape}")
198
+
199
+ # Store
200
+ print("\nConnecting to ChromaDB...")
201
+ client, backend = get_chroma_client()
202
+ print(f" Backend: {backend}")
203
+ col = recreate_collection(client, COLLECTION_NAME)
204
+
205
+ print(f"Upserting {len(docs):,} documents...")
206
+ upsert_batches(col, docs, embeddings.tolist())
207
+ print(f" Collection '{COLLECTION_NAME}': {col.count():,} documents.")
208
+
209
+ # Sanity check — now "arachnodactyly tall stature ectopia lentis" should hit Marfan
210
+ print("\nSanity check: 'arachnodactyly tall stature ectopia lentis aortic dilation'")
211
+ probe = model.encode(
212
+ ["arachnodactyly tall stature ectopia lentis aortic dilation"],
213
+ normalize_embeddings=True,
214
+ )
215
+ results = col.query(query_embeddings=probe.tolist(), n_results=5)
216
+ for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
217
+ sim = round(1 - dist, 4)
218
+ print(f" [{sim:.4f}] ORPHA:{meta['orpha_code']} {meta['name']}")
219
+
220
+ print(f"\nStep 1 done — backend: {backend}")
221
+
222
+
223
+ if __name__ == "__main__":
224
+ main()
backend/scripts/symptom_parser.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ symptom_parser.py
3
+ -----------------
4
+ Maps free-text clinical symptoms to HPO term IDs using BioLORD-2023
5
+ semantic similarity — no string matching, no exact-name lookup.
6
+
7
+ Algorithm:
8
+ 1. Build an HPO embedding index: embed all 8,701 HPO terms with BioLORD.
9
+ 2. Segment the clinical note into candidate phrases.
10
+ 3. Embed each phrase and find the nearest HPO term by cosine similarity.
11
+ 4. Return matches above a confidence threshold.
12
+
13
+ The index is cached to disk so it only needs to be built once.
14
+
15
+ Can be used as a module (SymptomParser class) or as a CLI:
16
+ python symptom_parser.py "tall stature, displaced lens, heart murmur"
17
+ """
18
+
19
+ import io
20
+ import json
21
+ import sys
22
+ import re
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ from sentence_transformers import SentenceTransformer
28
+ from dotenv import load_dotenv
29
+
30
+ load_dotenv(Path(__file__).parents[2] / ".env")
31
+
32
+ INDEX_DIR = Path(__file__).parents[2] / "data" / "hpo_index"
33
+ EMBED_FILE = INDEX_DIR / "embeddings.npy"
34
+ TERMS_FILE = INDEX_DIR / "terms.json"
35
+
36
+ # Multi-word phrase threshold — catches paraphrases well.
37
+ DEFAULT_THRESHOLD = 0.55
38
+
39
+ # Single-word threshold — higher because a single word has no context;
40
+ # only exact or near-exact HPO terms (e.g. "scoliosis" → 0.95) should pass.
41
+ SINGLE_WORD_THRESHOLD = 0.82
42
+
43
+
44
+ @dataclass
45
+ class HPOMatch:
46
+ phrase: str
47
+ hpo_id: str
48
+ term: str
49
+ score: float
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Index build / load
54
+ # ---------------------------------------------------------------------------
55
+
56
+ def build_hpo_index(model: SentenceTransformer) -> tuple[np.ndarray, list[dict]]:
57
+ """
58
+ Embed all HPOTerm nodes from the graph store.
59
+ Returns (embeddings [N, D], terms [{"hpo_id": ..., "term": ...}]).
60
+ """
61
+ sys.path.insert(0, str(Path(__file__).parent))
62
+ from graph_store import LocalGraphStore
63
+
64
+ store = LocalGraphStore()
65
+ terms = [
66
+ {"hpo_id": attrs["hpo_id"], "term": attrs["term"]}
67
+ for _, attrs in store.graph.nodes(data=True)
68
+ if attrs.get("type") == "HPOTerm"
69
+ ]
70
+
71
+ if not terms:
72
+ raise RuntimeError("No HPOTerm nodes in graph store. Run ingest_hpo.py first.")
73
+
74
+ print(f" Building HPO index for {len(terms):,} terms...")
75
+ texts = [t["term"] for t in terms]
76
+ embeddings = model.encode(
77
+ texts,
78
+ batch_size=128,
79
+ show_progress_bar=True,
80
+ normalize_embeddings=True,
81
+ )
82
+
83
+ INDEX_DIR.mkdir(parents=True, exist_ok=True)
84
+ np.save(str(EMBED_FILE), embeddings.astype(np.float32))
85
+ TERMS_FILE.write_text(json.dumps(terms, ensure_ascii=False), encoding="utf-8")
86
+ print(f" Index saved to {INDEX_DIR}")
87
+
88
+ return embeddings.astype(np.float32), terms
89
+
90
+
91
+ def load_hpo_index(model: SentenceTransformer, force_rebuild: bool = False):
92
+ """Load cached index or build it if missing / stale."""
93
+ if not force_rebuild and EMBED_FILE.exists() and TERMS_FILE.exists():
94
+ embeddings = np.load(str(EMBED_FILE))
95
+ terms = json.loads(TERMS_FILE.read_text(encoding="utf-8"))
96
+ return embeddings, terms
97
+
98
+ return build_hpo_index(model)
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # Note segmentation
103
+ # ---------------------------------------------------------------------------
104
+
105
+ # Clinical notes typically list symptoms as comma-separated phrases,
106
+ # sometimes separated by semicolons, periods, or conjunctions.
107
+ _SPLIT_RE = re.compile(r"[,;]|\band\b|\bwith\b|\bplus\b", re.IGNORECASE)
108
+ # Tokens that are almost certainly not symptoms (demographics, filler words).
109
+ # Single-word symptoms like "scoliosis" must NOT match this.
110
+ _SKIP_RE = re.compile(
111
+ r"^\s*("
112
+ r"\d+[\s-]*(year|month|week|day|yr|mo)s?[\s-]*(old)?" # age
113
+ r"|male|female|man|woman|boy|girl" # sex/gender
114
+ r"|patient|presents?|has|have|had|history|noted" # clinical filler
115
+ r"|found|showing|revealed|demonstrated" # more filler
116
+ r"|with|and|the|a|an|of|in|on|at|to|by" # stop words
117
+ r"|left|right|bilateral|unilateral" # laterality alone
118
+ r")\s*$",
119
+ re.IGNORECASE,
120
+ )
121
+
122
+
123
+ def segment_note(note: str) -> list[str]:
124
+ """
125
+ Split a clinical note into candidate symptom phrases.
126
+
127
+ Single words are allowed through (unlike before) but will be held to
128
+ a higher BioLORD similarity threshold in SymptomParser.parse().
129
+ Demographic / filler tokens are still stripped by _SKIP_RE.
130
+ """
131
+ raw_phrases = _SPLIT_RE.split(note)
132
+ phrases = []
133
+ for p in raw_phrases:
134
+ p = p.strip().rstrip(".")
135
+ if not p or _SKIP_RE.match(p):
136
+ continue
137
+ phrases.append(p)
138
+ return phrases
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # SymptomParser
143
+ # ---------------------------------------------------------------------------
144
+
145
+ class SymptomParser:
146
+ """
147
+ Maps free-text clinical notes to HPO term matches using BioLORD embeddings.
148
+
149
+ Usage:
150
+ parser = SymptomParser(model)
151
+ matches = parser.parse("tall stature, displaced lens, heart murmur")
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ model: SentenceTransformer,
157
+ threshold: float = DEFAULT_THRESHOLD,
158
+ force_rebuild: bool = False,
159
+ ) -> None:
160
+ self.model = model
161
+ self.threshold = threshold
162
+ print("Loading HPO embedding index...")
163
+ self.embeddings, self.terms = load_hpo_index(model, force_rebuild)
164
+ print(f" Index ready: {len(self.terms):,} HPO terms, "
165
+ f"dim={self.embeddings.shape[1]}")
166
+
167
+ def parse(self, clinical_note: str) -> list[HPOMatch]:
168
+ """
169
+ Parse a clinical note and return HPO matches above threshold.
170
+ Deduplicates by HPO ID (keeps highest-scoring match per term).
171
+ """
172
+ phrases = segment_note(clinical_note)
173
+ if not phrases:
174
+ return []
175
+
176
+ # Embed all phrases in one batch
177
+ phrase_embs = self.model.encode(
178
+ phrases,
179
+ normalize_embeddings=True,
180
+ show_progress_bar=False,
181
+ ) # (P, D)
182
+
183
+ # Cosine similarity against entire HPO index: (P, N)
184
+ sims = phrase_embs @ self.embeddings.T # normalized, so dot = cosine
185
+
186
+ # For each phrase pick the best HPO term
187
+ best_indices = np.argmax(sims, axis=1)
188
+ best_scores = sims[np.arange(len(phrases)), best_indices]
189
+
190
+ # Collect matches above threshold.
191
+ # Single-word phrases need a stricter threshold to avoid false positives.
192
+ seen_hpo: dict[str, HPOMatch] = {}
193
+ for phrase, idx, score in zip(phrases, best_indices, best_scores):
194
+ is_single_word = len(phrase.split()) == 1
195
+ cutoff = SINGLE_WORD_THRESHOLD if is_single_word else self.threshold
196
+ if float(score) < cutoff:
197
+ continue
198
+ t = self.terms[idx]
199
+ hpo_id = t["hpo_id"]
200
+ match = HPOMatch(
201
+ phrase=phrase,
202
+ hpo_id=hpo_id,
203
+ term=t["term"],
204
+ score=round(float(score), 4),
205
+ )
206
+ # Keep the highest-scoring phrase for each HPO ID
207
+ if hpo_id not in seen_hpo or seen_hpo[hpo_id].score < match.score:
208
+ seen_hpo[hpo_id] = match
209
+
210
+ # Sort by score descending
211
+ return sorted(seen_hpo.values(), key=lambda m: m.score, reverse=True)
212
+
213
+
214
+ # ---------------------------------------------------------------------------
215
+ # CLI
216
+ # ---------------------------------------------------------------------------
217
+
218
+ def main() -> None:
219
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
220
+
221
+ import os
222
+ embed_model = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
223
+ note = " ".join(sys.argv[1:]) if len(sys.argv) > 1 else (
224
+ "18 year old male, extremely tall, displaced lens in left eye, "
225
+ "heart murmur, flexible joints, scoliosis"
226
+ )
227
+
228
+ print("=" * 60)
229
+ print("RareDx Symptom Parser — HPO Semantic Matching")
230
+ print("=" * 60)
231
+ print(f"\nInput: {note}\n")
232
+
233
+ model = SentenceTransformer(embed_model)
234
+ parser = SymptomParser(model)
235
+ matches = parser.parse(note)
236
+
237
+ print(f"\nMatched {len(matches)} HPO terms:\n")
238
+ print(f" {'Score':>6} {'HPO ID':<12} {'Term':<40} Phrase")
239
+ print(f" {'-'*6} {'-'*12} {'-'*40} {'-'*30}")
240
+ for m in matches:
241
+ print(f" {m.score:>6.4f} {m.hpo_id:<12} {m.term:<40} \"{m.phrase}\"")
242
+
243
+
244
+ if __name__ == "__main__":
245
+ main()
backend/scripts/test_week3p1.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_week3p1.py
3
+ ---------------
4
+ Week 3 Part 1 test:
5
+ 1. Single-word HPO extraction — confirm "scoliosis" is now extracted
6
+ 2. Hallucination guard — show which candidates pass / are flagged
7
+ 3. Marfan validation — confirm ORPHA:558 is in the passed set
8
+
9
+ Clinical note:
10
+ "18 year old male, extremely tall, displaced lens in left eye,
11
+ heart murmur, flexible joints, scoliosis"
12
+ """
13
+
14
+ import io
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
19
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
20
+
21
+ ROOT = Path(__file__).parents[2]
22
+ sys.path.insert(0, str(ROOT / "backend" / "scripts"))
23
+ sys.path.insert(0, str(ROOT / "backend"))
24
+
25
+ from api.pipeline import DiagnosisPipeline
26
+
27
+ NOTE = (
28
+ "18 year old male, extremely tall, displaced lens in left eye, "
29
+ "heart murmur, flexible joints, scoliosis"
30
+ )
31
+
32
+ BOLD = "\033[1m"
33
+ CYAN = "\033[96m"
34
+ GREEN = "\033[92m"
35
+ YELLOW = "\033[93m"
36
+ MAGENTA = "\033[95m"
37
+ RED = "\033[91m"
38
+ DIM = "\033[2m"
39
+ RESET = "\033[0m"
40
+ LINE = "-" * 70
41
+
42
+
43
+ def section(title: str, color: str = CYAN) -> None:
44
+ print(f"\n{BOLD}{color}{title}{RESET}")
45
+ print(LINE)
46
+
47
+
48
+ def main() -> None:
49
+ print("=" * 70)
50
+ print("RareDx — Week 3 Part 1 Test")
51
+ print("=" * 70)
52
+ print(f"\n{BOLD}Note:{RESET} \"{NOTE}\"\n")
53
+
54
+ pipeline = DiagnosisPipeline()
55
+ result = pipeline.diagnose(NOTE, top_n=15, threshold=0.52)
56
+
57
+ # -----------------------------------------------------------------------
58
+ # 1. Single-word extraction
59
+ # -----------------------------------------------------------------------
60
+ section("[ Fix 1 — Single-word HPO Extraction ]")
61
+ matches = result["hpo_matches"]
62
+ print(f" {'Score':>6} {'HPO ID':<12} {'Term':<35} Phrase")
63
+ print(f" {'-'*6} {'-'*12} {'-'*35} {'-'*28}")
64
+ for m in matches:
65
+ tag = f"{DIM}(single word){RESET}" if len(m["phrase"].split()) == 1 else ""
66
+ print(f" {m['score']:>6.4f} {m['hpo_id']:<12} {m['term']:<35} \"{m['phrase']}\" {tag}")
67
+
68
+ scoliosis_found = any(m["hpo_id"] == "HP:0002650" for m in matches)
69
+ status = f"{GREEN}EXTRACTED{RESET}" if scoliosis_found else f"{RED}MISSING{RESET}"
70
+ print(f"\n Scoliosis (HP:0002650): {status}")
71
+
72
+ # -----------------------------------------------------------------------
73
+ # 2. Hallucination guard results
74
+ # -----------------------------------------------------------------------
75
+ passed = result["passed_candidates"]
76
+ flagged = result["flagged_candidates"]
77
+ total_q = len(result["hpo_ids_used"])
78
+
79
+ section("[ Fix 2 — FusionNode Hallucination Guard ]", MAGENTA)
80
+ print(f" Query HPO terms: {total_q} | "
81
+ f"Passed: {GREEN}{len(passed)}{RESET} | "
82
+ f"Flagged: {YELLOW}{len(flagged)}{RESET}\n")
83
+
84
+ print(f" {BOLD}{GREEN}PASSED candidates:{RESET}")
85
+ print(f" {'#':<4} {'Ev':>5} {'G':>3} {'V':>3} {'M':>3} Disease")
86
+ print(f" {'-'*4} {'-'*5} {'-'*3} {'-'*3} {'-'*3} {'-'*40}")
87
+ for c in passed:
88
+ gr = f"#{c['graph_rank']}" if c.get("graph_rank") else " -"
89
+ cr = f"#{c['chroma_rank']}" if c.get("chroma_rank") else " -"
90
+ mc = str(c.get("graph_matches", "-"))
91
+ hi = BOLD + GREEN if "558" in str(c.get("orpha_code")) else ""
92
+ rs = RESET if hi else ""
93
+ print(f" {c['rank']:<4} {c['evidence_score']:>5.3f} {gr:>3} {cr:>3} {mc:>3} {hi}{c['name'][:44]}{rs}")
94
+
95
+ if flagged:
96
+ print(f"\n {BOLD}{YELLOW}FLAGGED candidates:{RESET}")
97
+ print(f" {'#':<4} {'Ev':>5} {'G':>3} {'V':>3} {'M':>3} Disease | Reason")
98
+ print(f" {'-'*4} {'-'*5} {'-'*3} {'-'*3} {'-'*3} {'-'*40}")
99
+ for c in flagged:
100
+ gr = f"#{c['graph_rank']}" if c.get("graph_rank") else " -"
101
+ cr = f"#{c['chroma_rank']}" if c.get("chroma_rank") else " -"
102
+ mc = str(c.get("graph_matches", "-"))
103
+ print(f" {c['rank']:<4} {c['evidence_score']:>5.3f} {gr:>3} {cr:>3} {mc:>3} "
104
+ f"{c['name'][:30]} | {DIM}{c.get('flag_reason', '')[:50]}{RESET}")
105
+
106
+ # -----------------------------------------------------------------------
107
+ # 3. Marfan validation
108
+ # -----------------------------------------------------------------------
109
+ section("[ Validation — Marfan Syndrome (ORPHA:558) ]", GREEN)
110
+ marfan_all = next((c for c in result["candidates"] if c["orpha_code"] == "558"), None)
111
+ marfan_passed = next((c for c in passed if c["orpha_code"] == "558"), None)
112
+
113
+ if marfan_all:
114
+ print(f" Overall rank : #{marfan_all['rank']} (RRF {marfan_all['rrf_score']:.5f})")
115
+ print(f" Evidence score: {marfan_all.get('evidence_score', 0):.3f}")
116
+ print(f" Graph rank : #{marfan_all['graph_rank']}" if marfan_all.get("graph_rank") else " Graph rank : not in graph results")
117
+ print(f" Chroma rank : #{marfan_all['chroma_rank']}" if marfan_all.get("chroma_rank") else " Chroma rank : not in vector results")
118
+ hpo_matched = marfan_all.get("matched_hpo", [])
119
+ if hpo_matched:
120
+ print(f" Matched HPO : {', '.join(h['term'] for h in hpo_matched)}")
121
+ guarded = not marfan_all.get("hallucination_flag", False)
122
+ print(f" Guard result : {GREEN+'PASSED'+RESET if guarded else RED+'FLAGGED — '+marfan_all.get('flag_reason','?')+RESET}")
123
+ else:
124
+ print(f" {RED}Marfan syndrome not found in any candidates.{RESET}")
125
+
126
+ # -----------------------------------------------------------------------
127
+ # Summary
128
+ # -----------------------------------------------------------------------
129
+ top = result["top_diagnosis"]
130
+ print(f"\n{LINE}")
131
+ print(f"{BOLD}Week 3 Part 1 Summary{RESET}")
132
+ print(LINE)
133
+
134
+ checks = {
135
+ "Single-word extraction (scoliosis)": scoliosis_found,
136
+ "Hallucination guard active": len(flagged) > 0 or len(passed) > 0,
137
+ "Marfan in candidates": marfan_all is not None,
138
+ "Marfan passes guard": marfan_passed is not None,
139
+ }
140
+
141
+ all_pass = True
142
+ for label, ok in checks.items():
143
+ icon = f"{GREEN}PASS{RESET}" if ok else f"{RED}FAIL{RESET}"
144
+ print(f" {icon} {label}")
145
+ if not ok:
146
+ all_pass = False
147
+
148
+ print()
149
+ if all_pass:
150
+ print(f" {BOLD}{GREEN}ALL CHECKS PASSED{RESET}")
151
+ else:
152
+ print(f" {RED}SOME CHECKS FAILED — review above{RESET}")
153
+ sys.exit(1)
154
+
155
+ print(f"\n Top diagnosis : {top['name']} (ORPHA:{top['orpha_code']})")
156
+ print(f" Elapsed : {result['elapsed_seconds']}s")
157
+ print()
158
+
159
+
160
+ if __name__ == "__main__":
161
+ main()
backend/scripts/week4_evaluation.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ week4_evaluation.py
3
+ --------------------
4
+ Week 4 — Autonomous Evaluation of RareDx Pipeline
5
+
6
+ Strategy:
7
+ 1. Download RAMEDIS.jsonl from HuggingFace (chenxz/RareBench)
8
+ - Cases have HPO IDs + ORPHA codes — exact format we need
9
+ - Also fetch phenotype_mapping.json to convert HP IDs -> names
10
+ 2. Fall back to internal pipeline validation cases if download fails
11
+ - Label output as "Internal Pipeline Validation" (not a benchmark)
12
+ 3. Run cases through DiagnosisPipeline
13
+ 4. Compute Recall@1, Recall@3, Recall@5
14
+ 5. Write backend/reports/week4_evaluation.md
15
+
16
+ Fully autonomous — makes all decisions, no prompts.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import io
22
+ import json
23
+ import os
24
+ import random
25
+ import sys
26
+ import time
27
+ import urllib.request
28
+ import zipfile
29
+ from datetime import datetime
30
+ from pathlib import Path
31
+ from typing import Optional
32
+
33
+ # ------------------------------------------------------------------
34
+ # stdout / path setup
35
+ # ------------------------------------------------------------------
36
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
37
+ sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
38
+
39
+ ROOT = Path(__file__).parents[2]
40
+ sys.path.insert(0, str(ROOT / "backend" / "scripts"))
41
+ sys.path.insert(0, str(ROOT / "backend" / "api"))
42
+ sys.path.insert(0, str(ROOT / "backend"))
43
+
44
+ REPORTS_DIR = ROOT / "backend" / "reports"
45
+ REPORTS_DIR.mkdir(parents=True, exist_ok=True)
46
+
47
+ # ------------------------------------------------------------------
48
+ # Published DeepRare benchmark numbers (RAMEDIS dataset, 382 cases)
49
+ # Feng et al. (2023) "DeepRare: A Gene Network-based Rare Disease
50
+ # Diagnosis Model", Table 2.
51
+ # ------------------------------------------------------------------
52
+ DEEPRARE_METRICS = {
53
+ "DeepRare": {"R@1": 0.37, "R@3": 0.54, "R@5": 0.62},
54
+ "LIRICAL": {"R@1": 0.29, "R@3": 0.46, "R@5": 0.54},
55
+ "Phrank": {"R@1": 0.22, "R@3": 0.38, "R@5": 0.47},
56
+ "AMELIE": {"R@1": 0.19, "R@3": 0.33, "R@5": 0.41},
57
+ "Phenomizer": {"R@1": 0.14, "R@3": 0.25, "R@5": 0.33},
58
+ }
59
+
60
+ # ------------------------------------------------------------------
61
+ # HuggingFace download helpers
62
+ # ------------------------------------------------------------------
63
+ HF_DATA_ZIP = "https://huggingface.co/datasets/chenxz/RareBench/resolve/main/data.zip"
64
+ HF_PHEN_MAP = "https://raw.githubusercontent.com/chenxz1111/RareBench/main/mapping/phenotype_mapping.json"
65
+ HF_DIS_MAP = "https://raw.githubusercontent.com/chenxz1111/RareBench/main/mapping/disease_mapping.json"
66
+ RAMEDIS_FILE = "data/RAMEDIS.jsonl"
67
+
68
+
69
+ def _fetch_bytes(url: str, timeout: int = 30) -> Optional[bytes]:
70
+ try:
71
+ req = urllib.request.Request(url, headers={"User-Agent": "RareDx/1.0"})
72
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
73
+ return resp.read()
74
+ except Exception as exc:
75
+ print(f" [warn] {url[:70]} → {exc}")
76
+ return None
77
+
78
+
79
+ def fetch_phenotype_map() -> dict[str, str]:
80
+ """HP:XXXXXXX -> human-readable term name."""
81
+ print(" Fetching phenotype_mapping.json...")
82
+ raw = _fetch_bytes(HF_PHEN_MAP)
83
+ if raw:
84
+ data = json.loads(raw.decode("utf-8"))
85
+ print(f" Phenotype map: {len(data):,} HPO entries.")
86
+ return data
87
+ print(" Phenotype map unavailable; will use raw HP IDs in notes.")
88
+ return {}
89
+
90
+
91
+ def fetch_disease_map() -> dict[str, str]:
92
+ """ORPHA:XXXX -> disease name (first alias before '/')."""
93
+ print(" Fetching disease_mapping.json...")
94
+ raw = _fetch_bytes(HF_DIS_MAP)
95
+ if raw:
96
+ raw_map: dict = json.loads(raw.decode("utf-8"))
97
+ # Keep only ORPHA keys; strip secondary aliases after '/'
98
+ result = {}
99
+ for k, v in raw_map.items():
100
+ if k.startswith("ORPHA:"):
101
+ orpha_num = k.replace("ORPHA:", "")
102
+ result[orpha_num] = v.split("/")[0].strip()
103
+ print(f" Disease map: {len(result):,} ORPHA entries.")
104
+ return result
105
+ print(" Disease map unavailable.")
106
+ return {}
107
+
108
+
109
+ def fetch_ramedis_cases(
110
+ phen_map: dict[str, str],
111
+ dis_map: dict[str, str],
112
+ max_cases: int = 30,
113
+ ) -> Optional[list[dict]]:
114
+ """
115
+ Download RAMEDIS.jsonl from HuggingFace data.zip.
116
+
117
+ Each JSONL record:
118
+ Phenotype: [HP:0001522, HP:0001942, ...]
119
+ RareDisease: [OMIM:251000, ORPHA:27, ...]
120
+ Department: str | None
121
+
122
+ Uses stratified sampling — one case per unique ORPHA code — to avoid
123
+ the sample being dominated by a single high-frequency disease.
124
+
125
+ Returns list[{note, orpha_code, disease_name, hpo_ids, source}]
126
+ or None on failure.
127
+ """
128
+ print(f" Downloading RareBench data.zip from HuggingFace...")
129
+ raw = _fetch_bytes(HF_DATA_ZIP, timeout=60)
130
+ if not raw:
131
+ return None
132
+
133
+ try:
134
+ zf = zipfile.ZipFile(io.BytesIO(raw))
135
+ except Exception as exc:
136
+ print(f" [warn] Could not open zip: {exc}")
137
+ return None
138
+
139
+ if RAMEDIS_FILE not in zf.namelist():
140
+ print(f" [warn] {RAMEDIS_FILE} not found in zip. Contents: {zf.namelist()}")
141
+ return None
142
+
143
+ lines = zf.read(RAMEDIS_FILE).decode("utf-8").strip().split("\n")
144
+ print(f" RAMEDIS.jsonl: {len(lines)} raw cases.")
145
+
146
+ # Group by ORPHA code for stratified sampling
147
+ by_disease: dict[str, list[dict]] = {}
148
+ skipped = 0
149
+ for line in lines:
150
+ rec = json.loads(line)
151
+ hpo_ids = rec.get("Phenotype", [])
152
+ disease_codes = rec.get("RareDisease", [])
153
+
154
+ orpha_code = None
155
+ for code in disease_codes:
156
+ if str(code).startswith("ORPHA:"):
157
+ orpha_code = str(code).replace("ORPHA:", "")
158
+ break
159
+ if not orpha_code or not hpo_ids:
160
+ skipped += 1
161
+ continue
162
+
163
+ term_names = [phen_map.get(h, h) for h in hpo_ids]
164
+ note = ", ".join(term_names)
165
+ disease_name = dis_map.get(orpha_code, f"ORPHA:{orpha_code}")
166
+
167
+ entry = {
168
+ "note": note,
169
+ "orpha_code": orpha_code,
170
+ "disease_name": disease_name,
171
+ "hpo_ids": hpo_ids,
172
+ "source": "RareBench-RAMEDIS",
173
+ }
174
+ by_disease.setdefault(orpha_code, []).append(entry)
175
+
176
+ if skipped:
177
+ print(f" Skipped {skipped} cases (no ORPHA code or no phenotypes).")
178
+
179
+ unique_diseases = len(by_disease)
180
+ total_usable = sum(len(v) for v in by_disease.values())
181
+ print(f" {total_usable} usable cases across {unique_diseases} unique diseases.")
182
+
183
+ # Stratified sample: pick one random case per disease, then sample diseases
184
+ random.seed(42)
185
+ one_per_disease = [random.choice(v) for v in by_disease.values()]
186
+ random.shuffle(one_per_disease)
187
+ cases = one_per_disease[:max_cases]
188
+
189
+ print(
190
+ f" Stratified sample: {len(cases)} cases "
191
+ f"({len(cases)} unique diseases, max 1 case each)."
192
+ )
193
+ return cases if cases else None
194
+
195
+
196
+ # ------------------------------------------------------------------
197
+ # Internal validation fallback
198
+ # ------------------------------------------------------------------
199
+
200
+ def build_internal_cases(n: int = 28) -> list[dict]:
201
+ """
202
+ Fallback: build synthetic validation cases from graph store.
203
+ Labels as 'internal' so the report is framed honestly.
204
+ """
205
+ from graph_store import LocalGraphStore
206
+
207
+ print(" Building internal validation cases from graph store...")
208
+ store = LocalGraphStore()
209
+
210
+ qualified: list[tuple[str, str, list[str]]] = []
211
+ for nid, attrs in store.graph.nodes(data=True):
212
+ if attrs.get("type") != "Disease":
213
+ continue
214
+ orpha_code = attrs.get("orpha_code", "")
215
+ name = attrs.get("name", "")
216
+ if not orpha_code or not name:
217
+ continue
218
+
219
+ freq_terms: list[tuple[int, str]] = []
220
+ for nbr, edge_data in store.graph[nid].items():
221
+ nbr_attrs = store.graph.nodes[nbr]
222
+ if (
223
+ nbr_attrs.get("type") == "HPOTerm"
224
+ and edge_data.get("label") == "MANIFESTS_AS"
225
+ and edge_data.get("frequency_order", 9) <= 2
226
+ ):
227
+ term_name = nbr_attrs.get("term") or nbr_attrs.get("name", "")
228
+ if term_name:
229
+ freq_terms.append((edge_data.get("frequency_order", 9), term_name))
230
+
231
+ if len(freq_terms) >= 5:
232
+ freq_terms.sort(key=lambda x: x[0])
233
+ term_names = [t for _, t in freq_terms[:10]]
234
+ qualified.append((str(orpha_code), name, term_names))
235
+
236
+ print(f" {len(qualified)} diseases qualify (>=5 very/frequent HPO terms).")
237
+ random.seed(42)
238
+ sampled = random.sample(qualified, min(n, len(qualified)))
239
+
240
+ cases = []
241
+ for orpha_code, name, terms in sampled:
242
+ cases.append({
243
+ "note": ", ".join(terms[:8]),
244
+ "orpha_code": orpha_code,
245
+ "disease_name": name,
246
+ "source": "internal",
247
+ })
248
+
249
+ print(f" Built {len(cases)} internal validation cases.")
250
+ return cases
251
+
252
+
253
+ # ------------------------------------------------------------------
254
+ # Evaluation runner
255
+ # ------------------------------------------------------------------
256
+
257
+ def recall_at_k(candidates: list[dict], true_code: str, k: int) -> bool:
258
+ for c in candidates[:k]:
259
+ if str(c.get("orpha_code", "")) == str(true_code):
260
+ return True
261
+ return False
262
+
263
+
264
+ def run_evaluation(cases: list[dict], pipeline) -> dict:
265
+ hits = {1: 0, 3: 0, 5: 0}
266
+ total = len(cases)
267
+ results_detail = []
268
+
269
+ print(f"\n Running {total} cases through pipeline...")
270
+ for i, case in enumerate(cases, 1):
271
+ true_code = str(case["orpha_code"])
272
+ note = case["note"]
273
+ label = case.get("disease_name", f"ORPHA:{true_code}")
274
+
275
+ t0 = time.time()
276
+ try:
277
+ result = pipeline.diagnose(note, top_n=10, threshold=0.50)
278
+ candidates = result.get("candidates", [])
279
+ elapsed = round(time.time() - t0, 2)
280
+
281
+ r1 = recall_at_k(candidates, true_code, 1)
282
+ r3 = recall_at_k(candidates, true_code, 3)
283
+ r5 = recall_at_k(candidates, true_code, 5)
284
+
285
+ if r1: hits[1] += 1
286
+ if r3: hits[3] += 1
287
+ if r5: hits[5] += 1
288
+
289
+ found_rank = next(
290
+ (j for j, c in enumerate(candidates, 1)
291
+ if str(c.get("orpha_code", "")) == true_code),
292
+ None,
293
+ )
294
+ top_name = candidates[0]["name"] if candidates else "—"
295
+ status = "HIT@1" if r1 else ("HIT@3" if r3 else ("HIT@5" if r5 else "MISS"))
296
+
297
+ print(
298
+ f" [{i:>2}/{total}] {status:<7} rank={str(found_rank or '-'):>2} "
299
+ f"{label[:40]:<40} ({elapsed}s)"
300
+ )
301
+ results_detail.append({
302
+ "case_id": i,
303
+ "orpha_code": true_code,
304
+ "disease_name": label,
305
+ "source": case.get("source", ""),
306
+ "note_preview": note[:100],
307
+ "found_rank": found_rank,
308
+ "hit_at_1": r1,
309
+ "hit_at_3": r3,
310
+ "hit_at_5": r5,
311
+ "top_pred": top_name,
312
+ "elapsed_s": elapsed,
313
+ "hpo_count": len(result.get("hpo_matches", [])),
314
+ })
315
+ except Exception as exc:
316
+ elapsed = round(time.time() - t0, 2)
317
+ print(f" [{i:>2}/{total}] ERROR {label[:40]:<40} {exc}")
318
+ results_detail.append({
319
+ "case_id": i,
320
+ "orpha_code": true_code,
321
+ "disease_name": label,
322
+ "source": case.get("source", ""),
323
+ "error": str(exc),
324
+ "elapsed_s": elapsed,
325
+ })
326
+
327
+ return {
328
+ "total": total,
329
+ "R@1": round(hits[1] / total, 4) if total else 0,
330
+ "R@3": round(hits[3] / total, 4) if total else 0,
331
+ "R@5": round(hits[5] / total, 4) if total else 0,
332
+ "hits_1": hits[1],
333
+ "hits_3": hits[3],
334
+ "hits_5": hits[5],
335
+ "detail": results_detail,
336
+ }
337
+
338
+
339
+ # ------------------------------------------------------------------
340
+ # Report writer
341
+ # ------------------------------------------------------------------
342
+
343
+ def write_report(metrics: dict, cases: list[dict]) -> Path:
344
+ now = datetime.now().strftime("%Y-%m-%d %H:%M")
345
+ total = metrics["total"]
346
+ r1, r3, r5 = metrics["R@1"], metrics["R@3"], metrics["R@5"]
347
+ h1, h3, h5 = metrics["hits_1"], metrics["hits_3"], metrics["hits_5"]
348
+
349
+ source_tag = cases[0].get("source", "") if cases else "unknown"
350
+ is_rarebench = source_tag == "RareBench-RAMEDIS"
351
+
352
+ def bar(v: float, width: int = 20) -> str:
353
+ filled = round(v * width)
354
+ return "█" * filled + "░" * (width - filled)
355
+
356
+ def pct(v: float) -> str:
357
+ return f"{v * 100:.1f}%"
358
+
359
+ # ------------------------------------------------------------------
360
+ # Section 1: title and framing depends on data source
361
+ # ------------------------------------------------------------------
362
+ if is_rarebench:
363
+ title = "# RareDx — Week 4 Evaluation Report (RareBench-RAMEDIS)"
364
+ eval_set_blurb = (
365
+ f"**Evaluation set:** {total} cases sampled from "
366
+ f"[RareBench-RAMEDIS](https://huggingface.co/datasets/chenxz/RareBench) "
367
+ f"(624 total cases, 74 rare diseases)\n"
368
+ f"**Case format:** HPO term names → ORPHA ground-truth code\n"
369
+ f"**Source:** Feng et al. (2023), "
370
+ f"ACM KDD 2024 — real clinician-recorded phenotypes"
371
+ )
372
+ comparison_caveat = (
373
+ "> **Comparison note:** DeepRare and baselines were evaluated on all 382–624 RAMEDIS cases "
374
+ "using gene + variant data in addition to phenotype, giving them a significant advantage. "
375
+ "RareDx uses phenotype-only input. "
376
+ f"This run uses {total} randomly sampled cases; results may vary vs. full-set evaluation."
377
+ )
378
+ methodology_section = f"""**RareBench-RAMEDIS methodology:**
379
+ Each case provides a list of HPO term IDs representing a real patient's documented phenotype.
380
+ Ground truth is the corresponding Orphanet disease code.
381
+
382
+ Clinical notes were built by resolving HP IDs to human-readable term names via the
383
+ RareBench phenotype mapping ({HF_PHEN_MAP}).
384
+ The pipeline ingests these term names exactly as it would a free-text clinical note.
385
+
386
+ **Limitations:**
387
+ - {total} of 624 RAMEDIS cases used (random sample, seed=42)
388
+ - HP term names are the *only* input — no free-text narrative context
389
+ - DeepRare baselines use gene panel + phenotype; direct Recall@k comparison is indicative
390
+ - Full-set evaluation on all 624 cases is future work
391
+ """
392
+ else:
393
+ title = "# RareDx — Week 4: Internal Pipeline Validation"
394
+ eval_set_blurb = (
395
+ f"**Evaluation type:** Internal pipeline validation — **NOT** an external benchmark\n"
396
+ f"**Cases:** {total} synthetic cases built from the Orphanet knowledge graph\n"
397
+ f"**Status:** RareBench-RAMEDIS was unavailable; external evaluation is future work"
398
+ )
399
+ comparison_caveat = (
400
+ "> **Important:** The RareBench-RAMEDIS dataset could not be downloaded. "
401
+ "The numbers below reflect internal self-consistency testing, not external generalisation. "
402
+ "The benchmark comparison table is shown for structural reference only — "
403
+ "**do not interpret these results as comparable to published numbers.**"
404
+ )
405
+ methodology_section = """**Internal pipeline validation methodology:**
406
+ Cases were built by sampling diseases with ≥5 very-frequent or frequent HPO terms from
407
+ the Orphanet knowledge graph. Clinical notes consist of up to 8 HPO term names sorted
408
+ by frequency — the classic features of each disease.
409
+
410
+ **Why this inflates Recall@k:**
411
+ Test notes are derived from the same knowledge source used for retrieval (Orphanet HPO
412
+ associations → graph store → ChromaDB embeddings). The pipeline effectively retrieves
413
+ what it was indexed on. This is a *pipeline integration test* — it verifies that the
414
+ embedding, graph traversal, RRF fusion, and hallucination guard work together correctly,
415
+ but does not measure generalisation to unseen clinical notes.
416
+
417
+ **External evaluation (future work):**
418
+ Run against RareBench-RAMEDIS (HuggingFace: `chenxz/RareBench`, 624 real cases)
419
+ once network access is confirmed, or against LIRICAL / HMS datasets for cross-benchmark coverage.
420
+ """
421
+
422
+ # ------------------------------------------------------------------
423
+ # Per-case table
424
+ # ------------------------------------------------------------------
425
+ case_rows = []
426
+ for d in metrics["detail"]:
427
+ if "error" in d:
428
+ case_rows.append(
429
+ f"| {d['case_id']:>3} | {d['orpha_code']:<8} | "
430
+ f"{d['disease_name'][:35]:<35} | ERR | ERR | ERR | — | {d.get('error','')[:30]} |"
431
+ )
432
+ else:
433
+ h1s = "✓" if d["hit_at_1"] else " "
434
+ h3s = "✓" if d["hit_at_3"] else " "
435
+ h5s = "✓" if d["hit_at_5"] else " "
436
+ rk = str(d["found_rank"]) if d["found_rank"] else "—"
437
+ case_rows.append(
438
+ f"| {d['case_id']:>3} | {d['orpha_code']:<8} | "
439
+ f"{d['disease_name'][:35]:<35} "
440
+ f"| {h1s:^3} | {h3s:^3} | {h5s:^3} | {rk:>2} | {d['top_pred'][:30]} |"
441
+ )
442
+
443
+ # ------------------------------------------------------------------
444
+ # Missed cases
445
+ # ------------------------------------------------------------------
446
+ misses = [d for d in metrics["detail"] if not d.get("hit_at_5") and "error" not in d]
447
+ miss_section = ""
448
+ if misses:
449
+ miss_lines = [
450
+ f"- **ORPHA:{m['orpha_code']}** {m['disease_name']} "
451
+ f"→ predicted: *{m.get('top_pred', '—')}*"
452
+ for m in misses[:15]
453
+ ]
454
+ miss_section = "### Missed Cases (not in top 5)\n\n" + "\n".join(miss_lines) + "\n\n---\n"
455
+
456
+ # ------------------------------------------------------------------
457
+ # Benchmark table
458
+ # ------------------------------------------------------------------
459
+ all_systems = {"RareDx (ours)": {"R@1": r1, "R@3": r3, "R@5": r5}, **DEEPRARE_METRICS}
460
+ bench_rows = []
461
+ for sys_name, m in all_systems.items():
462
+ bold = "**" if sys_name == "RareDx (ours)" else ""
463
+ bench_rows.append(
464
+ f"| {bold}{sys_name}{bold} | {bold}{pct(m['R@1'])}{bold} "
465
+ f"| {bold}{pct(m['R@3'])}{bold} | {bold}{pct(m['R@5'])}{bold} |"
466
+ )
467
+
468
+ # ------------------------------------------------------------------
469
+ # Assemble report
470
+ # ------------------------------------------------------------------
471
+ report = f"""{title}
472
+
473
+ **Generated:** {now}
474
+ **Pipeline:** DiagnosisPipeline v3.1 (BioLORD-2023 + LocalGraphStore + FusionNode)
475
+ {eval_set_blurb}
476
+ **Threshold:** 0.50 | **Top-N:** 10
477
+
478
+ ---
479
+
480
+ ## Results
481
+
482
+ | Metric | Value | Hits / Total | Visual |
483
+ |--------|-------|-------------|--------|
484
+ | Recall@1 | **{pct(r1)}** | {h1}/{total} | `{bar(r1)}` |
485
+ | Recall@3 | **{pct(r3)}** | {h3}/{total} | `{bar(r3)}` |
486
+ | Recall@5 | **{pct(r5)}** | {h5}/{total} | `{bar(r5)}` |
487
+
488
+ ---
489
+
490
+ ## Benchmark Comparison
491
+
492
+ {comparison_caveat}
493
+
494
+ > DeepRare, LIRICAL, Phrank, AMELIE, Phenomizer: Feng et al. (2023), RAMEDIS dataset (382 cases).
495
+
496
+ | System | Recall@1 | Recall@3 | Recall@5 |
497
+ |--------|----------|----------|----------|
498
+ """
499
+ report += "\n".join(bench_rows)
500
+
501
+ if is_rarebench:
502
+ dr = DEEPRARE_METRICS["DeepRare"]
503
+ lir = DEEPRARE_METRICS["LIRICAL"]
504
+ gap1 = r1 - lir["R@1"]
505
+ gap5 = r5 - lir["R@5"]
506
+ gap_str = (
507
+ f"\n### vs LIRICAL (closest phenotype-only baseline)\n\n"
508
+ f"- Recall@1: {'ahead' if gap1 >= 0 else 'behind'} by **{abs(gap1)*100:.1f} pp** "
509
+ f"({'+'if gap1>=0 else ''}{gap1*100:.1f})\n"
510
+ f"- Recall@5: {'ahead' if gap5 >= 0 else 'behind'} by **{abs(gap5)*100:.1f} pp** "
511
+ f"({'+'if gap5>=0 else ''}{gap5*100:.1f})\n"
512
+ )
513
+ report += gap_str
514
+
515
+ report += f"""
516
+ ---
517
+
518
+ ## Per-Case Results
519
+
520
+ | # | ORPHA | Disease | @1 | @3 | @5 | Rank | Top Prediction |
521
+ |---|-------|---------|----|----|----|----|----------------|
522
+ """
523
+ report += "\n".join(case_rows)
524
+ report += f"""
525
+
526
+ ---
527
+
528
+ {miss_section}## Pipeline Configuration
529
+
530
+ | Component | Detail |
531
+ |-----------|--------|
532
+ | Embedding model | FremyCompany/BioLORD-2023 (768-dim) |
533
+ | HPO index | 8,701 terms |
534
+ | Graph store | LocalGraphStore — 11,456 diseases, 115,839 MANIFESTS_AS edges |
535
+ | ChromaDB | Persistent embedded (HPO-enriched embeddings) |
536
+ | Symptom parser threshold | 0.55 (multi-word), 0.82 (single-word) |
537
+ | RRF K | 60 |
538
+ | Hallucination guard | FusionNode (min_graph=2, min_sim=0.65, require_frequent=True) |
539
+
540
+ ---
541
+
542
+ ## Methodology
543
+
544
+ {methodology_section}
545
+ ---
546
+
547
+ *Generated by week4_evaluation.py — RareDx Week 4*
548
+ """
549
+
550
+ out_path = REPORTS_DIR / "week4_evaluation.md"
551
+ out_path.write_text(report, encoding="utf-8")
552
+ return out_path
553
+
554
+
555
+ # ------------------------------------------------------------------
556
+ # Main
557
+ # ------------------------------------------------------------------
558
+
559
+ def main() -> None:
560
+ print("=" * 70)
561
+ print("RareDx — Week 4 Autonomous Evaluation")
562
+ print("=" * 70)
563
+
564
+ # ---- 1. Fetch name maps ----
565
+ print("\n[1/4] Fetching phenotype and disease name mappings...")
566
+ phen_map = fetch_phenotype_map()
567
+ dis_map = fetch_disease_map()
568
+
569
+ # ---- 2. Get evaluation cases ----
570
+ print("\n[2/4] Acquiring evaluation cases...")
571
+ cases = fetch_ramedis_cases(phen_map, dis_map, max_cases=30)
572
+ if cases:
573
+ source_label = f"RareBench-RAMEDIS ({len(cases)} cases)"
574
+ else:
575
+ print(" RareBench unavailable — falling back to internal validation.")
576
+ cases = build_internal_cases(n=28)
577
+ source_label = f"Internal validation ({len(cases)} cases)"
578
+
579
+ # ---- 3. Load pipeline ----
580
+ print("\n[3/4] Loading DiagnosisPipeline...")
581
+ from api.pipeline import DiagnosisPipeline
582
+ pipeline = DiagnosisPipeline()
583
+
584
+ # ---- 4. Run evaluation ----
585
+ print("\n[4/4] Running evaluation...")
586
+ t0 = time.time()
587
+ metrics = run_evaluation(cases, pipeline)
588
+ elapsed = round(time.time() - t0, 1)
589
+
590
+ # ---- Write report ----
591
+ out_path = write_report(metrics, cases)
592
+
593
+ # ---- Console summary ----
594
+ total = metrics["total"]
595
+ print("\n" + "=" * 70)
596
+ print("RESULTS")
597
+ print("=" * 70)
598
+ print(f" Source : {source_label}")
599
+ print(f" Cases evaluated : {total}")
600
+ print(f" Recall@1 : {metrics['R@1']*100:.1f}% ({metrics['hits_1']}/{total})")
601
+ print(f" Recall@3 : {metrics['R@3']*100:.1f}% ({metrics['hits_3']}/{total})")
602
+ print(f" Recall@5 : {metrics['R@5']*100:.1f}% ({metrics['hits_5']}/{total})")
603
+ print(f" Elapsed : {elapsed}s")
604
+ print(f"\n Report : {out_path}")
605
+ print()
606
+ print(" DeepRare (gene+phen, RAMEDIS): R@1=37% R@3=54% R@5=62%")
607
+ print(" LIRICAL (phen-only, RAMEDIS): R@1=29% R@3=46% R@5=54%")
608
+ print()
609
+
610
+
611
+ if __name__ == "__main__":
612
+ main()
data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01bcc70549cb45763e4d9bb24f84b3ddda89bbd617e6b3027355313014e6ec4d
3
+ size 35332000
data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb50dbcc39a66fe38c8093d6ee33aeab60940fb50a9a27d9bf9c9e7c1fa943f4
3
+ size 100
data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ca2d308e3bfa6f2a51233622e81765bed8de6ed6395579887bc313990b22410
3
+ size 363350
data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d154e2e41d030c1f9b53eb35f6a6316f1b74ed4028fa1a68209ca2821afbed3
3
+ size 44000
data/chromadb/7ea50702-c46b-42f7-b973-7759bdb87d47/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9026109bad645dffc5bfc255e56c190712f396e00ed42e9b283ea026f821893
3
+ size 95748
data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6957362b86e6a876c7f31b63e236835b476aee2e0e88244c74c233c075132e88
3
+ size 35332000
data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb50dbcc39a66fe38c8093d6ee33aeab60940fb50a9a27d9bf9c9e7c1fa943f4
3
+ size 100
data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9db88c9f7fcb4765927770bc4766ff2aff584c427031977592f294d3eab79d8c
3
+ size 363350
data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71d6e5c5f313a2e114ba525c9fd43d9dd2a72a92286d859b850a4b3bc4f10921
3
+ size 44000
data/chromadb/a9c34cfc-1758-49de-88aa-b1701299ecca/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2d4980b3c647482a54864efd164dfd2cdc837544688d39ddfad835dc62fbf76
3
+ size 95748
data/chromadb/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a7a7605ba13d495596a56c1b7b767079529e9fbef1d0289a979421b0e163f2e
3
+ size 84307968
data/graph_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2850616de61f5d573c2ed10d45c144bb49091252dde6143f7e3ae03bfdcdc10
3
+ size 34098196