File size: 16,087 Bytes
d990530
 
 
d91cbff
d990530
d91cbff
d990530
 
 
 
 
 
 
 
 
d91cbff
 
 
 
 
 
 
 
d990530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d91cbff
d990530
d91cbff
 
 
d990530
d91cbff
 
 
 
 
 
 
d990530
 
 
 
f8b04c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08fce0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7b50cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d990530
 
caab91b
d990530
 
 
 
 
 
caab91b
d990530
 
 
 
 
 
 
d91cbff
d990530
 
 
 
caab91b
d990530
 
 
 
 
 
caab91b
d990530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caab91b
d990530
 
 
 
 
 
caab91b
d990530
 
 
 
 
d91cbff
 
 
 
caab91b
d91cbff
 
 
 
 
 
caab91b
d91cbff
 
 
 
 
 
 
 
 
caab91b
d91cbff
 
 
ed21ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1ee732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d91cbff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caab91b
d91cbff
 
 
 
 
 
caab91b
d91cbff
 
 
 
 
 
 
 
 
 
 
 
 
 
caab91b
d91cbff
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
from __future__ import annotations

import uuid
from contextlib import asynccontextmanager
from datetime import date
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from civicsetu.models.enums import Jurisdiction, QueryType
from civicsetu.models.schemas import Citation
from tests.conftest import _base_state


@pytest.fixture(autouse=True)
def reset_graph_topology_cache():
    """Clear the graph route module's topology cache around every test.

    The cache entries are reset before the test body runs so each test starts
    with nothing cached, and reset again afterwards so a payload cached by one
    test cannot be observed by whatever runs next in the same process.
    """
    from civicsetu.api.routes import graph as graph_routes

    def _clear() -> None:
        graph_routes._topo_cache["data"] = None
        graph_routes._topo_cache["ts"] = 0.0

    _clear()
    yield
    # Teardown: leave no cached topology behind for later tests.
    _clear()


def _make_citation(section_id: str = "18") -> Citation:
    """Build a ``Citation`` with fixed placeholder fields and a random chunk id.

    Only ``section_id`` is caller-controlled; the document name, jurisdiction,
    effective date and source URL are constants, and ``chunk_id`` is a fresh
    ``uuid.uuid4()`` on every call.
    """
    fields = {
        "section_id": section_id,
        "doc_name": "RERA Act 2016",
        "jurisdiction": Jurisdiction.CENTRAL,
        "effective_date": date(2016, 5, 1),
        "source_url": "https://example.com/rera.pdf",
        "chunk_id": uuid.uuid4(),
    }
    return Citation(**fields)


@pytest.fixture
def client():
    """Yield ``(test_client, compiled_graph_mock)`` for an app built with patched dependencies.

    ``civicsetu.agent.graph.get_compiled_graph`` is patched to return a
    ``MagicMock`` and ``civicsetu.api.main.create_checkpointer`` is patched to
    return an async context manager yielding an ``AsyncMock``. The same graph
    mock is also stored on ``app.state.graph`` so tests can stub ``ainvoke``.
    """
    compiled_graph = MagicMock()
    checkpointer = AsyncMock()

    @asynccontextmanager
    async def checkpointer_cm():
        yield checkpointer

    with patch(
        "civicsetu.agent.graph.get_compiled_graph", return_value=compiled_graph
    ), patch(
        "civicsetu.api.main.create_checkpointer", return_value=checkpointer_cm()
    ):
        # Imported while the patches are active, as the app factory lives in
        # the module being patched.
        from civicsetu.api.main import create_app

        app = create_app()
        app.state.graph = compiled_graph

        with TestClient(app) as test_client:
            yield test_client, compiled_graph


# ── API tests: app startup, GET /, /api/v1/graph/*, POST /api/v1/query ────────

def test_app_startup_warms_reranker_from_retrieval_module():
    """Opening a ``TestClient`` on a fresh app must call the patched ``_get_ranker`` once.

    The checkpointer, compiled graph, driver open/close hooks and the
    embedding warm-up are all replaced with mocks for the duration.
    """
    checkpointer = AsyncMock()

    @asynccontextmanager
    async def checkpointer_cm():
        yield checkpointer

    # Patchers are built first and then entered in the order listed below.
    checkpointer_patch = patch(
        "civicsetu.api.main.create_checkpointer", return_value=checkpointer_cm()
    )
    graph_patch = patch("civicsetu.agent.graph.get_compiled_graph", return_value=MagicMock())
    get_driver_patch = patch("civicsetu.api.main.get_driver", new=AsyncMock())
    close_driver_patch = patch("civicsetu.api.main.close_driver", new=AsyncMock())
    embedding_patch = patch("civicsetu.retrieval.warm_embedding_model")
    ranker_patch = patch("civicsetu.retrieval.reranker._get_ranker")

    with checkpointer_patch, graph_patch, get_driver_patch, close_driver_patch:
        with embedding_patch, ranker_patch as ranker_loader:
            from civicsetu.api.main import create_app

            application = create_app()

            # Entering and leaving the client context is the whole exercise.
            with TestClient(application):
                pass

    ranker_loader.assert_called_once()


def test_app_startup_on_non_windows_does_not_shadow_asyncio():
    """With ``sys.platform`` patched to ``"linux"``, startup still runs both warm-up hooks once.

    The platform is patched through ``civicsetu.api.main.sys``; the
    checkpointer, compiled graph and driver hooks are mocked out.
    """
    checkpointer = AsyncMock()

    @asynccontextmanager
    async def checkpointer_cm():
        yield checkpointer

    # Patchers are built first and then entered in the order listed below.
    platform_patch = patch("civicsetu.api.main.sys.platform", "linux")
    checkpointer_patch = patch(
        "civicsetu.api.main.create_checkpointer", return_value=checkpointer_cm()
    )
    graph_patch = patch("civicsetu.agent.graph.get_compiled_graph", return_value=MagicMock())
    get_driver_patch = patch("civicsetu.api.main.get_driver", new=AsyncMock())
    close_driver_patch = patch("civicsetu.api.main.close_driver", new=AsyncMock())
    embedding_patch = patch("civicsetu.retrieval.warm_embedding_model")
    ranker_patch = patch("civicsetu.retrieval.reranker._get_ranker")

    with platform_patch, checkpointer_patch, graph_patch, get_driver_patch, close_driver_patch:
        with embedding_patch as embedding_warmer, ranker_patch as ranker_loader:
            from civicsetu.api.main import create_app

            application = create_app()

            with TestClient(application):
                pass

    embedding_warmer.assert_called_once()
    ranker_loader.assert_called_once()


def test_root_returns_ok_when_frontend_not_served(client):
    """``GET /`` answers 200 with a JSON body whose ``status`` is ``"ok"``."""
    http, _graph = client

    root_response = http.get("/")

    assert root_response.status_code == 200
    payload = root_response.json()
    assert payload["status"] == "ok"


def test_graph_topology_returns_empty_payload_when_neo4j_auth_fails(client):
    """Topology endpoint returns 200 with an all-empty payload when the graph store raises.

    Both ``GraphStore.get_topology`` and ``GraphStore.graph_stats`` are made
    to raise a ``RuntimeError`` carrying a Neo4j "Unauthorized" message.
    """
    http, _graph = client

    failing_topology = AsyncMock(
        side_effect=RuntimeError("Neo.ClientError.Security.Unauthorized")
    )
    failing_stats = AsyncMock(
        side_effect=RuntimeError("Neo.ClientError.Security.Unauthorized")
    )
    zeroed_stats = {
        "docs": 0,
        "sections": 0,
        "refs": 0,
        "has_sec": 0,
        "derived_from": 0,
    }
    expected_payload = {"nodes": [], "edges": [], "stats": zeroed_stats}

    topology_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_topology", new=failing_topology
    )
    stats_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.graph_stats", new=failing_stats
    )
    with topology_patch, stats_patch:
        response = http.get("/api/v1/graph/topology")

    assert response.status_code == 200
    assert response.json() == expected_payload


def test_query_returns_200_with_citations(client):
    """A graph result with one citation yields 200 and that citation in the response body."""
    http, graph = client
    graph_result = {
        "raw_response": "Under Section 18, the promoter must...",
        "citations": [_make_citation("18")],
        "confidence_score": 0.9,
        "query_type": QueryType.CROSS_REFERENCE,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    graph.ainvoke = AsyncMock(return_value=graph_result)

    request_body = {"query": "What does Section 18 say?"}
    response = http.post("/api/v1/query", json=request_body)

    assert response.status_code == 200
    payload = response.json()
    assert "answer" in payload
    returned_citations = payload["citations"]
    assert len(returned_citations) == 1
    assert returned_citations[0]["section_id"] == "18"
    assert "session_id" in payload


def test_query_returns_insufficient_when_no_citations(client):
    """With an empty citation list the endpoint still returns 200 and an "Insufficient" answer."""
    http, graph = client
    uncited_result = {
        "raw_response": "I don't know",
        "citations": [],
        "confidence_score": 0.2,
        "query_type": QueryType.FACT_LOOKUP,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    graph.ainvoke = AsyncMock(return_value=uncited_result)

    response = http.post(
        "/api/v1/query",
        json={"query": "Some obscure question here"},
    )

    assert response.status_code == 200
    answer_text = response.json()["answer"]
    assert "Insufficient" in answer_text


def test_query_rejects_short_query(client):
    """Posting the two-character query ``"hi"`` is rejected with HTTP 422."""
    http, _graph = client
    too_short = {"query": "hi"}

    response = http.post("/api/v1/query", json=too_short)

    assert response.status_code == 422


def test_query_rejects_top_k_out_of_range(client):
    """Posting an otherwise valid query with ``top_k=25`` is rejected with HTTP 422."""
    http, _graph = client
    request_body = {"query": "What does Section 18 say?", "top_k": 25}

    response = http.post("/api/v1/query", json=request_body)

    assert response.status_code == 422


def test_query_response_always_has_disclaimer(client):
    """The query response body carries a non-empty ``disclaimer`` field."""
    http, graph = client
    graph_result = {
        "raw_response": "Under Section 18...",
        "citations": [_make_citation()],
        "confidence_score": 0.8,
        "query_type": QueryType.FACT_LOOKUP,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    graph.ainvoke = AsyncMock(return_value=graph_result)

    response = http.post("/api/v1/query", json={"query": "What does Section 18 say?"})
    payload = response.json()

    assert "disclaimer" in payload
    disclaimer = payload["disclaimer"]
    assert len(disclaimer) > 0


def test_query_reuses_provided_session_id(client):
    """A caller-supplied ``session_id`` is echoed back and used as the graph ``thread_id``."""
    http, graph = client
    session_id = "my-session-abc"
    graph_result = {
        "raw_response": "Under Section 18...",
        "citations": [_make_citation()],
        "confidence_score": 0.8,
        "query_type": QueryType.FACT_LOOKUP,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    graph.ainvoke = AsyncMock(return_value=graph_result)

    request_body = {"query": "What does Section 18 say?", "session_id": session_id}
    response = http.post("/api/v1/query", json=request_body)

    payload = response.json()
    assert response.status_code == 200
    assert payload["session_id"] == session_id

    # ainvoke is expected to have been called with exactly (state, config).
    _state, run_config = graph.ainvoke.call_args.args
    assert run_config["configurable"]["thread_id"] == session_id


def test_query_retries_once_on_transient_graph_failure(client):
    """If the first ``ainvoke`` raises and the second succeeds, the endpoint returns 200.

    The mock must have been awaited exactly twice.
    """
    http, graph = client
    transient_error = RuntimeError("terminating connection due to administrator command")
    successful_result = {
        "raw_response": "Under Section 18...",
        "citations": [_make_citation()],
        "confidence_score": 0.8,
        "query_type": QueryType.FACT_LOOKUP,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    # First await raises, second await returns the result.
    graph.ainvoke = AsyncMock(side_effect=[transient_error, successful_result])

    response = http.post("/api/v1/query", json={"query": "What does Section 18 say?"})

    assert response.status_code == 200
    assert graph.ainvoke.await_count == 2


def test_query_preflights_driver_before_invoke(client):
    """Handling a query awaits ``civicsetu.api.routes.query.refresh_driver`` exactly once."""
    http, graph = client
    graph_result = {
        "raw_response": "Under Section 18...",
        "citations": [_make_citation()],
        "confidence_score": 0.8,
        "query_type": QueryType.FACT_LOOKUP,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    graph.ainvoke = AsyncMock(return_value=graph_result)

    refresh_mock = AsyncMock()
    with patch("civicsetu.api.routes.query.refresh_driver", new=refresh_mock):
        response = http.post("/api/v1/query", json={"query": "What does Section 18 say?"})

    assert response.status_code == 200
    refresh_mock.assert_awaited_once()


def test_graph_topology_returns_nodes_and_edges(client):
    """Topology endpoint passes through node, edge and stats values from ``GraphStore``."""
    http, _graph = client

    section_node = {
        "chunk_id": "chunk-18",
        "section_id": "18",
        "title": "Return of amount and compensation",
        "jurisdiction": "CENTRAL",
        "doc_name": "RERA Act 2016",
        "is_active": True,
        "connection_count": 3,
    }
    reference_edge = {
        "source": "chunk-18",
        "target": "chunk-19",
        "edge_type": "REFERENCES",
    }
    # get_topology is stubbed to return a (nodes, edges) pair.
    topology_mock = AsyncMock(return_value=([section_node], [reference_edge]))
    stats_mock = AsyncMock(return_value={"sections": 10, "refs": 4, "derived_from": 2})

    topology_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_topology", new=topology_mock
    )
    stats_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.graph_stats", new=stats_mock
    )
    with topology_patch, stats_patch:
        response = http.get("/api/v1/graph/topology")

    assert response.status_code == 200
    payload = response.json()
    assert payload["nodes"][0]["section_id"] == "18"
    assert payload["edges"][0]["edge_type"] == "REFERENCES"
    assert payload["stats"]["sections"] == 10


def test_graph_topology_uses_cache(client):
    """Two back-to-back topology requests both succeed but await ``get_topology`` only once."""
    http, _graph = client

    topology_mock = AsyncMock(return_value=([], []))
    stats_mock = AsyncMock(return_value={"sections": 0})

    topology_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_topology", new=topology_mock
    )
    stats_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.graph_stats", new=stats_mock
    )
    with topology_patch, stats_patch:
        responses = [http.get("/api/v1/graph/topology") for _ in range(2)]

    first, second = responses
    assert first.status_code == 200
    assert second.status_code == 200
    assert topology_mock.await_count == 1


def test_graph_section_content_returns_stitched_chunks_and_connections(client):
    """Section endpoint returns both DB rows as chunks plus the outgoing reference.

    The mocked session yields two rows (pages 1 and 2); ``GraphStore`` reports
    one referenced section and nothing for the other three relation lookups.
    The response must list pages ``[1, 2]``, label the first connection
    ``REFERENCES_OUT`` and expose the rows' ``source_url``.
    """
    http, _graph = client

    def make_row(chunk_id, text, page_number):
        # Rows differ only in chunk id, text and page number.
        return {
            "chunk_id": chunk_id,
            "section_title": "Return of amount and compensation",
            "text": text,
            "page_number": page_number,
            "doc_name": "RERA Act 2016",
            "jurisdiction": "CENTRAL",
            "effective_date": date(2016, 5, 1),
            "source_url": "https://example.com/rera.pdf",
        }

    rows = [
        make_row("chunk-18-a", "First chunk.", 1),
        make_row("chunk-18-b", "Second chunk.", 2),
    ]

    query_result = MagicMock()
    query_result.mappings.return_value.all.return_value = rows
    db_session = AsyncMock()
    db_session.execute.return_value = query_result

    @asynccontextmanager
    async def session_cm():
        yield db_session

    outgoing_refs = [{"section_id": "19", "title": "Rights", "jurisdiction": "CENTRAL"}]

    session_patch = patch("civicsetu.api.routes.graph.AsyncSessionLocal", side_effect=session_cm)
    referenced_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_referenced_sections",
        new=AsyncMock(return_value=outgoing_refs),
    )
    referencing_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_sections_referencing",
        new=AsyncMock(return_value=[]),
    )
    derived_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_derived_act_sections",
        new=AsyncMock(return_value=[]),
    )
    deriving_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_deriving_rule_sections",
        new=AsyncMock(return_value=[]),
    )

    with session_patch, referenced_patch, referencing_patch, derived_patch, deriving_patch:
        response = http.get("/api/v1/graph/section/18?jurisdiction=CENTRAL")

    assert response.status_code == 200
    payload = response.json()
    pages = [chunk["page_number"] for chunk in payload["chunks"]]
    assert pages == [1, 2]
    assert payload["connected_sections"][0]["edge_type"] == "REFERENCES_OUT"
    assert payload["source_url"] == "https://example.com/rera.pdf"


def test_graph_section_content_resolves_graph_chunk_id_case_insensitive_status(client):
    """Passing ``chunk_id`` triggers two DB lookups, both filtering on ``lower(status) = 'active'``.

    The first mocked ``execute`` result supplies a section/jurisdiction
    mapping, the second supplies the chunk rows. The response must carry the
    chunk id from the second result.
    """
    http, _graph = client

    # First execute(): single mapping identifying the section.
    node_lookup = MagicMock()
    node_lookup.mappings.return_value.first.return_value = {
        "section_id": "4",
        "jurisdiction": "CENTRAL",
    }

    # Second execute(): the chunk rows for that section.
    section_chunk = {
        "chunk_id": "postgres-chunk-4",
        "section_id": "4",
        "section_title": "Prior registration of real estate project",
        "text": "Section text.",
        "page_number": 1,
        "doc_name": "RERA Act 2016",
        "jurisdiction": "CENTRAL",
        "effective_date": date(2016, 5, 1),
        "source_url": "https://example.com/rera.pdf",
    }
    chunk_lookup = MagicMock()
    chunk_lookup.mappings.return_value.all.return_value = [section_chunk]

    db_session = AsyncMock()
    db_session.execute.side_effect = [node_lookup, chunk_lookup]

    @asynccontextmanager
    async def session_cm():
        yield db_session

    session_patch = patch("civicsetu.api.routes.graph.AsyncSessionLocal", side_effect=session_cm)
    referenced_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_referenced_sections",
        new=AsyncMock(return_value=[]),
    )
    referencing_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_sections_referencing",
        new=AsyncMock(return_value=[]),
    )
    derived_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_derived_act_sections",
        new=AsyncMock(return_value=[]),
    )
    deriving_patch = patch(
        "civicsetu.api.routes.graph.GraphStore.get_deriving_rule_sections",
        new=AsyncMock(return_value=[]),
    )

    with session_patch, referenced_patch, referencing_patch, derived_patch, deriving_patch:
        response = http.get(
            "/api/v1/graph/section/4?jurisdiction=CENTRAL&chunk_id=72ecb226-2aae-4c6c-bd23-5dc3a9439642"
        )

    assert response.status_code == 200
    assert response.json()["chunks"][0]["chunk_id"] == "postgres-chunk-4"

    execute_calls = db_session.execute.await_args_list
    first_lookup_sql = str(execute_calls[0].args[0])
    section_lookup_sql = str(execute_calls[1].args[0])
    assert "lower(status) = 'active'" in first_lookup_sql
    assert "lower(status) = 'active'" in section_lookup_sql


def test_section_context_query_sets_skip_classifier_and_source_section(client):
    """Section-context queries seed the graph state and reuse the given session id.

    The initial state passed to ``ainvoke`` must have ``skip_classifier`` set
    to ``True``, ``source_section_id`` equal to the requested section and
    ``query_type`` equal to ``QueryType.CROSS_REFERENCE``; the config's
    ``thread_id`` must equal the supplied ``session_id``.
    """
    http, graph = client
    graph_result = {
        "raw_response": "Section 18 requires refund with interest.",
        "citations": [_make_citation("18")],
        "confidence_score": 0.88,
        "query_type": QueryType.CROSS_REFERENCE,
        "conflict_warnings": [],
        "amendment_notice": None,
    }
    graph.ainvoke = AsyncMock(return_value=graph_result)

    request_body = {
        "query": "Explain this section",
        "section_id": "18",
        "jurisdiction": "CENTRAL",
        "session_id": "section-thread-1",
    }
    response = http.post("/api/v1/query/section-context", json=request_body)

    assert response.status_code == 200
    payload = response.json()
    assert payload["session_id"] == "section-thread-1"

    # ainvoke is expected to have been called with exactly (state, config).
    seeded_state, run_config = graph.ainvoke.call_args.args
    assert seeded_state["skip_classifier"] is True
    assert seeded_state["source_section_id"] == "18"
    assert seeded_state["query_type"] == QueryType.CROSS_REFERENCE
    assert run_config["configurable"]["thread_id"] == "section-thread-1"