File size: 4,963 Bytes
7ff7119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""chat_graph integration test with the dummy LLM.

For each of the 5 intents (list, extract, search, compare, validate), the right
tool sequence runs and the validator's anti-hallucination check does not block.
"""

from __future__ import annotations

import pytest

from store import HybridStore


@pytest.fixture
def populated_context(sample_pdf_bytes, tmp_path):
    """Build a ChatToolContext from two invoice PDFs run through the pipeline.

    Both uploads reuse the same sample PDF bytes under different filenames so
    that compare-style tests have two distinct documents to work with.
    """
    import asyncio

    from graph.pipeline_graph import build_pipeline_graph
    from tools import ChatToolContext

    hybrid_store = HybridStore(
        chroma_path=str(tmp_path / "chat_chroma"),
        collection_name="chat_test",
    )
    uploads = [
        ("invoice_january.pdf", sample_pdf_bytes),
        ("invoice_march.pdf", sample_pdf_bytes),
    ]
    graph = build_pipeline_graph(hybrid_store)
    result_state = asyncio.run(graph.ainvoke({"files": uploads}))

    ctx = ChatToolContext(store=hybrid_store)
    for doc in result_state.get("documents") or []:
        ctx.add_document(doc)
    return ctx


@pytest.mark.integration
@pytest.mark.asyncio
async def test_chat_list_intent(populated_context):
    """A 'what files do we have' question routes to the list_documents tool."""
    from langchain_core.messages import HumanMessage

    from graph.chat_graph import build_chat_graph
    from providers import get_chat_model, get_dummy_handle

    handle = get_dummy_handle()
    handle.set_docs_hint(populated_context.list_filenames())

    chat_graph = build_chat_graph(get_chat_model("dummy"), populated_context)

    result = await chat_graph.ainvoke(
        {"messages": [HumanMessage(content="What documents are uploaded?")]}
    )

    assert result.get("intent") == "list"
    plan = result.get("plan") or []
    assert "list_documents" in plan
    # The synthesized answer must not be empty.
    assert result.get("final_answer", "")


@pytest.mark.integration
@pytest.mark.asyncio
async def test_chat_validate_intent(populated_context):
    """A 'check the math' request routes to the validate_document tool."""
    from langchain_core.messages import HumanMessage

    from graph.chat_graph import build_chat_graph
    from providers import get_chat_model, get_dummy_handle

    handle = get_dummy_handle()
    handle.set_docs_hint(populated_context.list_filenames())

    chat_graph = build_chat_graph(get_chat_model("dummy"), populated_context)

    result = await chat_graph.ainvoke(
        {"messages": [HumanMessage(content="Validate the math on invoice_january.pdf")]}
    )

    assert result.get("intent") == "validate"
    # At least one tool call must have executed.
    assert result.get("iteration_count", 0) >= 1


@pytest.mark.integration
@pytest.mark.asyncio
async def test_chat_compare_intent(populated_context):
    """A 'compare X and Y' request routes through the compare_documents flow."""
    from langchain_core.messages import HumanMessage

    from graph.chat_graph import build_chat_graph
    from providers import get_chat_model, get_dummy_handle

    handle = get_dummy_handle()
    handle.set_docs_hint(populated_context.list_filenames())

    chat_graph = build_chat_graph(get_chat_model("dummy"), populated_context)

    result = await chat_graph.ainvoke(
        {"messages": [HumanMessage(content="Compare the January and March invoices")]}
    )

    assert result.get("intent") == "compare"
    plan = result.get("plan") or []
    assert "compare_documents" in plan
    # The full compare flow is list β†’ get Γ— 2 β†’ compare β†’ synth; we only
    # require that at least one iteration actually ran.
    assert result.get("iteration_count", 0) >= 1


@pytest.mark.integration
@pytest.mark.asyncio
async def test_chat_search_intent(populated_context):
    """A 'find X' request routes to the search_documents (RAG) tool."""
    from langchain_core.messages import HumanMessage

    from graph.chat_graph import build_chat_graph
    from providers import get_chat_model, get_dummy_handle

    handle = get_dummy_handle()
    handle.set_docs_hint(populated_context.list_filenames())

    chat_graph = build_chat_graph(get_chat_model("dummy"), populated_context)

    result = await chat_graph.ainvoke(
        {"messages": [HumanMessage(content="Find the penalty clause")]}
    )

    assert result.get("intent") == "search"
    # At least one tool call must have executed.
    assert result.get("iteration_count", 0) >= 1


@pytest.mark.integration
@pytest.mark.asyncio
async def test_chat_extract_intent(populated_context):
    """A 'what is the gross total' question routes through the extract flow."""
    from langchain_core.messages import HumanMessage

    from graph.chat_graph import build_chat_graph
    from providers import get_chat_model, get_dummy_handle

    handle = get_dummy_handle()
    handle.set_docs_hint(populated_context.list_filenames())

    chat_graph = build_chat_graph(get_chat_model("dummy"), populated_context)

    result = await chat_graph.ainvoke(
        {"messages": [HumanMessage(content="What is the gross total on invoice_january.pdf?")]}
    )

    assert result.get("intent") == "extract"
    # At least one tool call must have executed.
    assert result.get("iteration_count", 0) >= 1