test: update tests for evidence-linked mock data and new features
- Update NCT ID references from NCT04000xxx to MOCK-NCT-* across all tests
- Update MCP test mocks from streaming to regular POST response
- Add live test fixtures (Gemini, MCP, MedGemma) in conftest.py
- Load .env in conftest for live test API key access
- Test "would match IF" phrasing in gap analysis page
- Test patient profile inclusion in MedGemma criterion evaluation
- Patch direct_pipeline in profile review advance test
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app/tests/test_gap_analysis_page.py +11 -1
- app/tests/test_profile_review_page.py +12 -2
- app/tests/test_trial_matching_page.py +2 -2
- conftest.py +47 -0
- tests/test_integration.py +6 -6
- trialpath/tests/test_mcp.py +10 -17
- trialpath/tests/test_medgemma.py +15 -2
app/tests/test_gap_analysis_page.py
CHANGED
|
@@ -30,8 +30,11 @@ def test_page_renders_without_error(gap_app):
|
|
| 30 |
|
| 31 |
|
| 32 |
def test_displays_gaps(gap_app):
|
|
|
|
|
|
|
| 33 |
all_md = " ".join(str(m.value) for m in gap_app.markdown)
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
def test_displays_recommended_actions(gap_app):
|
|
@@ -39,6 +42,13 @@ def test_displays_recommended_actions(gap_app):
|
|
| 39 |
assert "upload" in all_md.lower() or "request" in all_md.lower()
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def test_has_summary_button(gap_app):
|
| 43 |
labels = [str(b.label) for b in gap_app.button]
|
| 44 |
assert any("summary" in lbl.lower() or "generate" in lbl.lower() for lbl in labels)
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def test_displays_gaps(gap_app):
|
| 33 |
+
expander_labels = [str(e.label) for e in gap_app.expander]
|
| 34 |
+
all_labels = " ".join(expander_labels)
|
| 35 |
all_md = " ".join(str(m.value) for m in gap_app.markdown)
|
| 36 |
+
combined = all_labels + " " + all_md
|
| 37 |
+
assert "Brain MRI" in combined or "EGFR" in combined
|
| 38 |
|
| 39 |
|
| 40 |
def test_displays_recommended_actions(gap_app):
|
|
|
|
| 42 |
assert "upload" in all_md.lower() or "request" in all_md.lower()
|
| 43 |
|
| 44 |
|
| 45 |
+
def test_displays_would_match_phrasing(gap_app):
|
| 46 |
+
"""PRD core value proposition: 'You would match [trial] IF you had' phrasing."""
|
| 47 |
+
all_md = " ".join(str(m.value) for m in gap_app.markdown)
|
| 48 |
+
assert "You would match" in all_md
|
| 49 |
+
assert "IF you had" in all_md
|
| 50 |
+
|
| 51 |
+
|
| 52 |
def test_has_summary_button(gap_app):
|
| 53 |
labels = [str(b.label) for b in gap_app.button]
|
| 54 |
assert any("summary" in lbl.lower() or "generate" in lbl.lower() for lbl in labels)
|
app/tests/test_profile_review_page.py
CHANGED
|
@@ -1,9 +1,15 @@
|
|
| 1 |
"""Tests for app/pages/2_profile_review.py — PRESCREEN state."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import pytest
|
| 4 |
from streamlit.testing.v1 import AppTest
|
| 5 |
|
| 6 |
-
from app.services.mock_data import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
@pytest.fixture
|
|
@@ -52,7 +58,11 @@ def test_has_confirm_button(profile_app):
|
|
| 52 |
assert any("confirm" in lbl.lower() or "search" in lbl.lower() for lbl in labels)
|
| 53 |
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
confirm_btns = [
|
| 57 |
b
|
| 58 |
for b in profile_app.button
|
|
|
|
| 1 |
"""Tests for app/pages/2_profile_review.py — PRESCREEN state."""
|
| 2 |
|
| 3 |
+
from unittest.mock import patch
|
| 4 |
+
|
| 5 |
import pytest
|
| 6 |
from streamlit.testing.v1 import AppTest
|
| 7 |
|
| 8 |
+
from app.services.mock_data import (
|
| 9 |
+
MOCK_ELIGIBILITY_LEDGERS,
|
| 10 |
+
MOCK_PATIENT_PROFILE,
|
| 11 |
+
MOCK_TRIAL_CANDIDATES,
|
| 12 |
+
)
|
| 13 |
|
| 14 |
|
| 15 |
@pytest.fixture
|
|
|
|
| 58 |
assert any("confirm" in lbl.lower() or "search" in lbl.lower() for lbl in labels)
|
| 59 |
|
| 60 |
|
| 61 |
+
@patch(
|
| 62 |
+
"app.services.direct_pipeline.run_trial_search_and_evaluate",
|
| 63 |
+
return_value=(MOCK_TRIAL_CANDIDATES, MOCK_ELIGIBILITY_LEDGERS),
|
| 64 |
+
)
|
| 65 |
+
def test_confirm_advances_to_validate_trials(mock_pipeline, profile_app):
|
| 66 |
confirm_btns = [
|
| 67 |
b
|
| 68 |
for b in profile_app.button
|
app/tests/test_trial_matching_page.py
CHANGED
|
@@ -38,8 +38,8 @@ def test_displays_trial_nct_ids(matching_app):
|
|
| 38 |
# NCT IDs are in expander labels, not in markdown body text
|
| 39 |
expander_labels = [str(e.label) for e in matching_app.expander]
|
| 40 |
all_labels = " ".join(expander_labels)
|
| 41 |
-
assert "
|
| 42 |
-
assert "
|
| 43 |
|
| 44 |
|
| 45 |
def test_displays_traffic_light_colors(matching_app):
|
|
|
|
| 38 |
# NCT IDs are in expander labels, not in markdown body text
|
| 39 |
expander_labels = [str(e.label) for e in matching_app.expander]
|
| 40 |
all_labels = " ".join(expander_labels)
|
| 41 |
+
assert "MOCK-NCT-KEYNOTE999" in all_labels
|
| 42 |
+
assert "MOCK-NCT-FLAURA2" in all_labels
|
| 43 |
|
| 44 |
|
| 45 |
def test_displays_traffic_light_colors(matching_app):
|
conftest.py
CHANGED
|
@@ -2,10 +2,18 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 5 |
from unittest.mock import AsyncMock, MagicMock, patch
|
| 6 |
|
| 7 |
import pytest
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from app.services.mock_data import (
|
| 10 |
MOCK_ELIGIBILITY_LEDGERS,
|
| 11 |
MOCK_PATIENT_PROFILE,
|
|
@@ -146,3 +154,42 @@ def mock_mcp():
|
|
| 146 |
instance.get_study.return_value = MOCK_TRIAL_CANDIDATES[0].model_dump()
|
| 147 |
cls.return_value = instance
|
| 148 |
yield instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import os
|
| 6 |
from unittest.mock import AsyncMock, MagicMock, patch
|
| 7 |
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
+
try:
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
except ImportError:
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
from app.services.mock_data import (
|
| 18 |
MOCK_ELIGIBILITY_LEDGERS,
|
| 19 |
MOCK_PATIENT_PROFILE,
|
|
|
|
| 154 |
instance.get_study.return_value = MOCK_TRIAL_CANDIDATES[0].model_dump()
|
| 155 |
cls.return_value = instance
|
| 156 |
yield instance
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ---------------------------------------------------------------------------
|
| 160 |
+
# Live service fixtures (require real API keys / running servers)
|
| 161 |
+
# ---------------------------------------------------------------------------
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@pytest.fixture(scope="session")
|
| 165 |
+
def live_env():
|
| 166 |
+
"""Ensure env vars are loaded; skip the entire session block if missing."""
|
| 167 |
+
if not os.environ.get("GEMINI_API_KEY"):
|
| 168 |
+
pytest.skip("GEMINI_API_KEY not set — skipping live tests")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@pytest.fixture(scope="session")
|
| 172 |
+
def live_gemini(live_env):
|
| 173 |
+
"""Return a real GeminiPlanner wired to the Gemini API."""
|
| 174 |
+
from trialpath.services.gemini_planner import GeminiPlanner
|
| 175 |
+
|
| 176 |
+
return GeminiPlanner()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@pytest.fixture(scope="session")
|
| 180 |
+
def live_mcp_client(live_env):
|
| 181 |
+
"""Return a real ClinicalTrialsMCPClient."""
|
| 182 |
+
from trialpath.services.mcp_client import ClinicalTrialsMCPClient
|
| 183 |
+
|
| 184 |
+
return ClinicalTrialsMCPClient()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@pytest.fixture(scope="session")
|
| 188 |
+
def live_medgemma(live_env):
|
| 189 |
+
"""Return a real MedGemmaExtractor (skip if no HF_TOKEN)."""
|
| 190 |
+
if not os.environ.get("HF_TOKEN"):
|
| 191 |
+
pytest.skip("HF_TOKEN not set — skipping MedGemma live tests")
|
| 192 |
+
|
| 193 |
+
from trialpath.services.medgemma_extractor import MedGemmaExtractor
|
| 194 |
+
|
| 195 |
+
return MedGemmaExtractor()
|
tests/test_integration.py
CHANGED
|
@@ -72,15 +72,15 @@ class TestComponentModelIntegration:
|
|
| 72 |
assert len(spec["biomarkers"]) == 3
|
| 73 |
|
| 74 |
def test_trial_card_renders_green_trial(self):
|
| 75 |
-
#
|
| 76 |
trial = MOCK_TRIAL_CANDIDATES[1]
|
| 77 |
ledger = MOCK_ELIGIBILITY_LEDGERS[1]
|
| 78 |
spec = render_trial_card(trial, ledger)
|
| 79 |
assert spec["traffic_light"] == "green"
|
| 80 |
-
assert spec["nct_id"] == "
|
| 81 |
|
| 82 |
def test_trial_card_renders_yellow_trial(self):
|
| 83 |
-
#
|
| 84 |
trial = MOCK_TRIAL_CANDIDATES[0]
|
| 85 |
ledger = MOCK_ELIGIBILITY_LEDGERS[0]
|
| 86 |
spec = render_trial_card(trial, ledger)
|
|
@@ -88,7 +88,7 @@ class TestComponentModelIntegration:
|
|
| 88 |
assert len(spec["gaps"]) == 1
|
| 89 |
|
| 90 |
def test_trial_card_renders_red_trial(self):
|
| 91 |
-
#
|
| 92 |
trial = MOCK_TRIAL_CANDIDATES[2]
|
| 93 |
ledger = MOCK_ELIGIBILITY_LEDGERS[2]
|
| 94 |
spec = render_trial_card(trial, ledger)
|
|
@@ -97,7 +97,7 @@ class TestComponentModelIntegration:
|
|
| 97 |
def test_gap_card_renders_from_ledger_gap(self):
|
| 98 |
ledger = MOCK_ELIGIBILITY_LEDGERS[0]
|
| 99 |
gap = ledger.gaps[0]
|
| 100 |
-
spec = render_gap_card(gap, affected_trials=["
|
| 101 |
assert "Brain MRI" in spec["description"]
|
| 102 |
assert spec["importance_color"] == "red" # high importance
|
| 103 |
|
|
@@ -153,7 +153,7 @@ class TestDoctorPacketGeneration:
|
|
| 153 |
def test_all_trial_nct_ids_in_packet(self):
|
| 154 |
ledgers = MOCK_ELIGIBILITY_LEDGERS
|
| 155 |
packet_ids = [lg.nct_id for lg in ledgers]
|
| 156 |
-
expected_ids = ["
|
| 157 |
assert packet_ids == expected_ids
|
| 158 |
|
| 159 |
|
|
|
|
| 72 |
assert len(spec["biomarkers"]) == 3
|
| 73 |
|
| 74 |
def test_trial_card_renders_green_trial(self):
|
| 75 |
+
# MOCK-NCT-FLAURA2 is LIKELY_ELIGIBLE -> green
|
| 76 |
trial = MOCK_TRIAL_CANDIDATES[1]
|
| 77 |
ledger = MOCK_ELIGIBILITY_LEDGERS[1]
|
| 78 |
spec = render_trial_card(trial, ledger)
|
| 79 |
assert spec["traffic_light"] == "green"
|
| 80 |
+
assert spec["nct_id"] == "MOCK-NCT-FLAURA2"
|
| 81 |
|
| 82 |
def test_trial_card_renders_yellow_trial(self):
|
| 83 |
+
# MOCK-NCT-KEYNOTE999 is UNCERTAIN -> yellow
|
| 84 |
trial = MOCK_TRIAL_CANDIDATES[0]
|
| 85 |
ledger = MOCK_ELIGIBILITY_LEDGERS[0]
|
| 86 |
spec = render_trial_card(trial, ledger)
|
|
|
|
| 88 |
assert len(spec["gaps"]) == 1
|
| 89 |
|
| 90 |
def test_trial_card_renders_red_trial(self):
|
| 91 |
+
# MOCK-NCT-CM817 is LIKELY_INELIGIBLE -> red
|
| 92 |
trial = MOCK_TRIAL_CANDIDATES[2]
|
| 93 |
ledger = MOCK_ELIGIBILITY_LEDGERS[2]
|
| 94 |
spec = render_trial_card(trial, ledger)
|
|
|
|
| 97 |
def test_gap_card_renders_from_ledger_gap(self):
|
| 98 |
ledger = MOCK_ELIGIBILITY_LEDGERS[0]
|
| 99 |
gap = ledger.gaps[0]
|
| 100 |
+
spec = render_gap_card(gap, affected_trials=["MOCK-NCT-KEYNOTE999"])
|
| 101 |
assert "Brain MRI" in spec["description"]
|
| 102 |
assert spec["importance_color"] == "red" # high importance
|
| 103 |
|
|
|
|
| 153 |
def test_all_trial_nct_ids_in_packet(self):
|
| 154 |
ledgers = MOCK_ELIGIBILITY_LEDGERS
|
| 155 |
packet_ids = [lg.nct_id for lg in ledgers]
|
| 156 |
+
expected_ids = ["MOCK-NCT-KEYNOTE999", "MOCK-NCT-FLAURA2", "MOCK-NCT-CM817"]
|
| 157 |
assert packet_ids == expected_ids
|
| 158 |
|
| 159 |
|
trialpath/tests/test_mcp.py
CHANGED
|
@@ -33,21 +33,14 @@ class TestMCPClient:
|
|
| 33 |
def _mock_httpx(self, MockHTTP, response_data):
|
| 34 |
import json as _json
|
| 35 |
|
| 36 |
-
# Mock the streaming response from client.stream()  [truncated in extraction — reconstructed from surrounding deleted lines; verify against original commit]
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
)
|
| 42 |
-
mock_stream_response.headers = {"content-type": "application/json"}
|
| 43 |
-
|
| 44 |
-
# client.stream() returns an async context manager
|
| 45 |
-
mock_stream_ctx = MagicMock()
|
| 46 |
-
mock_stream_ctx.__aenter__ = AsyncMock(return_value=mock_stream_response)
|
| 47 |
-
mock_stream_ctx.__aexit__ = AsyncMock(return_value=None)
|
| 48 |
|
| 49 |
mock_client = MagicMock()
|
| 50 |
-
mock_client.stream.return_value = mock_stream_ctx  [truncated in extraction — reconstructed from the mock_stream_ctx setup above; verify against original commit]
|
| 51 |
|
| 52 |
# AsyncClient() itself is an async context manager
|
| 53 |
mock_client_ctx = MagicMock()
|
|
@@ -64,7 +57,7 @@ class TestMCPClient:
|
|
| 64 |
|
| 65 |
await client.search(sample_anchors)
|
| 66 |
|
| 67 |
-
call_args = mock_client.
|
| 68 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 69 |
query = body["params"]["arguments"]["query"]
|
| 70 |
|
|
@@ -79,7 +72,7 @@ class TestMCPClient:
|
|
| 79 |
|
| 80 |
await client.search(sample_anchors)
|
| 81 |
|
| 82 |
-
call_args = mock_client.
|
| 83 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 84 |
args = body["params"]["arguments"]
|
| 85 |
|
|
@@ -93,7 +86,7 @@ class TestMCPClient:
|
|
| 93 |
|
| 94 |
await client.search(sample_anchors)
|
| 95 |
|
| 96 |
-
call_args = mock_client.
|
| 97 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 98 |
filter_str = body["params"]["arguments"].get("filter", "")
|
| 99 |
|
|
@@ -135,7 +128,7 @@ class TestMCPClient:
|
|
| 135 |
country="United States",
|
| 136 |
)
|
| 137 |
|
| 138 |
-
call_args = mock_client.
|
| 139 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 140 |
args = body["params"]["arguments"]
|
| 141 |
|
|
|
|
| 33 |
def _mock_httpx(self, MockHTTP, response_data):
|
| 34 |
import json as _json
|
| 35 |
|
| 36 |
+
# Mock the response from client.post()
|
| 37 |
+
mock_response = MagicMock()
|
| 38 |
+
mock_response.raise_for_status = MagicMock()
|
| 39 |
+
mock_response.text = _json.dumps(response_data)
|
| 40 |
+
mock_response.headers = {"content-type": "application/json"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
mock_client = MagicMock()
|
| 43 |
+
mock_client.post = AsyncMock(return_value=mock_response)
|
| 44 |
|
| 45 |
# AsyncClient() itself is an async context manager
|
| 46 |
mock_client_ctx = MagicMock()
|
|
|
|
| 57 |
|
| 58 |
await client.search(sample_anchors)
|
| 59 |
|
| 60 |
+
call_args = mock_client.post.call_args
|
| 61 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 62 |
query = body["params"]["arguments"]["query"]
|
| 63 |
|
|
|
|
| 72 |
|
| 73 |
await client.search(sample_anchors)
|
| 74 |
|
| 75 |
+
call_args = mock_client.post.call_args
|
| 76 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 77 |
args = body["params"]["arguments"]
|
| 78 |
|
|
|
|
| 86 |
|
| 87 |
await client.search(sample_anchors)
|
| 88 |
|
| 89 |
+
call_args = mock_client.post.call_args
|
| 90 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 91 |
filter_str = body["params"]["arguments"].get("filter", "")
|
| 92 |
|
|
|
|
| 128 |
country="United States",
|
| 129 |
)
|
| 130 |
|
| 131 |
+
call_args = mock_client.post.call_args
|
| 132 |
body = call_args.kwargs.get("json", call_args[1].get("json", {}))
|
| 133 |
args = body["params"]["arguments"]
|
| 134 |
|
trialpath/tests/test_medgemma.py
CHANGED
|
@@ -160,9 +160,13 @@ class TestMedGemmaHFEndpoint:
|
|
| 160 |
"""evaluate_medical_criterion should return decision dict."""
|
| 161 |
decision_data = {
|
| 162 |
"decision": "met",
|
| 163 |
-
"reasoning": "Patient has EGFR exon 19 del",
|
| 164 |
"confidence": 0.95,
|
| 165 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
mock_choice = MagicMock()
|
| 167 |
mock_choice.message.content = json.dumps(decision_data)
|
| 168 |
mock_response = MagicMock()
|
|
@@ -171,10 +175,19 @@ class TestMedGemmaHFEndpoint:
|
|
| 171 |
with patch("trialpath.services.medgemma_extractor.InferenceClient") as MockClient:
|
| 172 |
MockClient.return_value.chat_completion.return_value = mock_response
|
| 173 |
extractor = MedGemmaExtractor(endpoint_url="http://test", hf_token="tok")
|
| 174 |
-
result = await extractor.evaluate_medical_criterion(
|
|
|
|
|
|
|
| 175 |
|
| 176 |
assert result["decision"] == "met"
|
| 177 |
assert result["confidence"] == 0.95
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
@pytest.mark.asyncio
|
| 180 |
async def test_retry_on_cold_start(self):
|
|
|
|
| 160 |
"""evaluate_medical_criterion should return decision dict."""
|
| 161 |
decision_data = {
|
| 162 |
"decision": "met",
|
| 163 |
+
"reasoning": "Patient has EGFR exon 19 del per biomarkers",
|
| 164 |
"confidence": 0.95,
|
| 165 |
}
|
| 166 |
+
patient_profile = {
|
| 167 |
+
"diagnosis": {"primary_condition": "NSCLC", "stage": "IIIA"},
|
| 168 |
+
"biomarkers": [{"name": "EGFR", "result": "exon 19 del"}],
|
| 169 |
+
}
|
| 170 |
mock_choice = MagicMock()
|
| 171 |
mock_choice.message.content = json.dumps(decision_data)
|
| 172 |
mock_response = MagicMock()
|
|
|
|
| 175 |
with patch("trialpath.services.medgemma_extractor.InferenceClient") as MockClient:
|
| 176 |
MockClient.return_value.chat_completion.return_value = mock_response
|
| 177 |
extractor = MedGemmaExtractor(endpoint_url="http://test", hf_token="tok")
|
| 178 |
+
result = await extractor.evaluate_medical_criterion(
|
| 179 |
+
"EGFR mutation positive", patient_profile, []
|
| 180 |
+
)
|
| 181 |
|
| 182 |
assert result["decision"] == "met"
|
| 183 |
assert result["confidence"] == 0.95
|
| 184 |
+
# Verify patient profile was included in the prompt
|
| 185 |
+
call_args = MockClient.return_value.chat_completion.call_args
|
| 186 |
+
user_content = call_args[1]["messages"][1]["content"]
|
| 187 |
+
prompt_text = user_content[0]["text"]
|
| 188 |
+
assert "Patient Profile" in prompt_text
|
| 189 |
+
assert "EGFR" in prompt_text
|
| 190 |
+
assert "exon 19 del" in prompt_text
|
| 191 |
|
| 192 |
@pytest.mark.asyncio
|
| 193 |
async def test_retry_on_cold_start(self):
|