# File size: 1,722 Bytes (revision d91cbff)
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from civicsetu.retrieval.cache import graph_cache, make_key
@pytest.mark.asyncio
async def test_graph_retrieve_cache_hit():
    """A pre-seeded cache entry is returned without querying the graph store.

    Seeds ``graph_cache`` under ``make_key("18", "all", 2)`` and checks that
    ``GraphRetriever.retrieve`` returns exactly the seeded value while
    ``GraphStore.get_referenced_sections`` is never invoked.
    """
    graph_cache.clear()
    fake_chunks = [MagicMock()]
    # NOTE(review): assumes "all" is the key segment used for jurisdiction=None
    # and 2 is the depth retrieve() uses by default -- confirm against make_key callers.
    graph_cache[make_key("18", "all", 2)] = fake_chunks
    try:
        with patch(
            "civicsetu.stores.graph_store.GraphStore.get_referenced_sections",
            side_effect=RuntimeError("should not run"),
        ) as mock_referenced:
            from civicsetu.retrieval.graph_retriever import GraphRetriever

            result = await GraphRetriever.retrieve(
                query="What does Section 18 reference?",
                jurisdiction=None,
            )
        assert result == fake_chunks
        # The RuntimeError side effect alone is not proof: if retrieve() caught
        # store errors, a cache miss could slip through. Check the mock directly.
        mock_referenced.assert_not_called()
    finally:
        # graph_cache is module-level; don't leak the MagicMock entry into other tests.
        graph_cache.clear()
@pytest.mark.asyncio
async def test_graph_retrieve_cache_miss_populates_cache():
    """On a cache miss the retriever queries the stores and caches the result.

    Every graph-store and vector-store lookup is stubbed to return an empty
    list, so ``retrieve`` is expected to return ``[]`` and to leave an entry
    under ``make_key("18", "all", 2)`` in ``graph_cache``.
    """
    graph_cache.clear()
    try:
        with (
            patch("civicsetu.stores.graph_store.GraphStore.get_referenced_sections", new=AsyncMock(return_value=[])),
            patch("civicsetu.stores.graph_store.GraphStore.get_sections_referencing", new=AsyncMock(return_value=[])),
            patch("civicsetu.stores.graph_store.GraphStore.get_derived_act_sections", new=AsyncMock(return_value=[])),
            patch("civicsetu.stores.graph_store.GraphStore.get_deriving_rule_sections", new=AsyncMock(return_value=[])),
            patch("civicsetu.stores.vector_store.VectorStore.get_by_section", new=AsyncMock(return_value=[])),
        ):
            from civicsetu.retrieval.graph_retriever import GraphRetriever

            result = await GraphRetriever.retrieve(
                query="What does Section 18 say?",
                jurisdiction=None,
            )
        assert result == []
        # NOTE(review): assumes jurisdiction=None maps to "all" and the default
        # depth is 2 in the retriever's key -- confirm against make_key callers.
        assert make_key("18", "all", 2) in graph_cache
    finally:
        # graph_cache is module-level; clear it so this entry can't satisfy a
        # later test's lookup and mask a real cache miss.
        graph_cache.clear()