"""Streamlit + asyncio integration helper.

Bridges Streamlit (uvloop) and LangGraph (asyncio) via a long-lived background
event loop (see app/async_runtime.py).

``run_async()`` and ``stream_async()`` are thin wrappers: every call uses
the same background loop, so persistent resources (ChromaDB, AsyncSqliteSaver,
sentence-transformers cache) are NOT rebuilt per call.

``run_with_progress()`` turns the ``"updates"`` events of
``astream(stream_mode=["updates", "values"])`` into progress-bar updates, one
per node.
"""

from __future__ import annotations

from collections.abc import AsyncIterator, Coroutine, Iterator
from typing import Any, Callable

from app.async_runtime import AsyncRuntime


def run_async(coro: Coroutine[Any, Any, Any]) -> Any:
    """Sync wrapper: run a coroutine on the long-lived background loop."""
    return AsyncRuntime.get().submit(coro)


def stream_async(async_gen: AsyncIterator[Any]) -> Iterator[Any]:
    """Async generator → sync iterator (compatible with Streamlit st.write_stream)."""
    yield from AsyncRuntime.get().submit_iter(async_gen)


_PROGRESS_LABEL_MAP = {
    "start_timer": "Starting",
    "ingest_per_doc": "Loading documents",
    "ingest_join": "Loading documents (join)",
    "classify_per_doc": "Classifying",
    "classify_join": "Classifying (join)",
    "extract_per_doc": "Extracting structured data",
    "extract_join": "Extracting (join)",
    "quote_validator": "Quote verification",
    "rag_index_per_doc": "Indexing",
    "rag_join": "Indexing (join)",
    "compare": "Cross-document checks",
    "risk": "Risk analysis",
    "report": "Generating report",
    "finish_timer": "Done",
}


def run_with_progress(
    graph,
    input_state: dict,
    on_progress: Callable[[int, int, str], None] | None = None,
    total_steps: int | None = None,
) -> dict:
    """LangGraph ``astream`` → progress-bar callback + final state.

    The background event loop drives the async generator; the ``on_progress``
    callback runs on the CALLER thread (Streamlit main thread) after every
    event, so ``st.progress(...)`` widgets can be updated safely.

    Args:
        graph: a CompiledStateGraph (or anything supporting astream).
        input_state: the graph entry state.
        on_progress: optional callback ``(step, total, label)``. Streamlit
                     widget calls are safe here (caller thread).
        total_steps: optional progress-bar denominator. When omitted,
                     ``max(step, 12)`` is used.

    Returns:
        The graph's final state (same as ``ainvoke()``), or an empty dict if
        no ``"values"`` event carrying a dict was received.
    """

    async def _astream_events():
        """Async generator: split multi-stream-mode into (stream_mode, event) pairs."""
        async for stream_mode, event in graph.astream(
            input_state, stream_mode=["updates", "values"]
        ):
            yield (stream_mode, event)

    final_state: dict = {}
    step = 0

    # ``submit_iter`` turns an async iterator into a sync one on the caller thread,
    # so the progress callback runs on the Streamlit main thread.
    for stream_mode, event in AsyncRuntime.get().submit_iter(_astream_events()):
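        if stream_mode == "updates":
            # An "updates" event is keyed by node name; count each node as one step.
            for node_name in event or {}:
                step += 1
                # Unknown nodes fall back to their raw name as the label.
                label = _PROGRESS_LABEL_MAP.get(node_name, node_name)
                if on_progress is not None:
                    # Without an explicit total, start from 12 and grow with ``step``
                    # so that step never exceeds total.
                    total = total_steps if total_steps is not None else max(step, 12)
                    on_progress(step, total, label)
        elif stream_mode == "values":
            # "values" events carry the state; keep the most recent dict seen.
            if isinstance(event, dict):
                final_state = event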

    return final_state
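

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of this module's API.
#
# ``app/async_runtime.py`` is not shown in this file, so the class below is an
# ASSUMPTION about how a long-lived background loop exposing ``submit`` and
# ``submit_iter`` could be built from the standard library. The names
# ``_SketchRuntime`` and ``close`` are hypothetical; the real ``AsyncRuntime``
# may differ in every detail.
#
# Possible usage:
#     rt = _SketchRuntime()
#     result = rt.submit(some_coroutine())
#     for item in rt.submit_iter(some_async_generator()):
#         ...  # runs on the caller thread
#     rt.close()
# ---------------------------------------------------------------------------
import asyncio
import threading


class _SketchRuntime:
    """Hypothetical stand-in for AsyncRuntime: one daemon thread, one loop."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def submit(self, coro):
        """Schedule ``coro`` on the background loop and block for its result."""
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def submit_iter(self, async_gen):
        """Yield items on the caller thread, fetching one item per round trip."""
        ait = async_gen.__aiter__()
        done = object()  # sentinel marking exhaustion of the async iterator

        async def _next():
            try:
                return await ait.__anext__()
            except StopAsyncIteration:
                return done

        while True:
            item = self.submit(_next())
            if item is done:
                return
            yield item

    def close(self) -> None:
        """Stop the loop, wait for the thread to exit, then release the loop."""
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()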