| """Streamlit + asyncio integration helper. |
| |
| Bridges Streamlit (uvloop) and LangGraph (asyncio) via a long-lived background |
| event loop (see app/async_runtime.py). |
| |
| ``run_async()`` and ``stream_async()`` are simple wrappers — every call uses |
| the same background loop, so persistent resources (ChromaDB, AsyncSqliteSaver, |
| sentence-transformers cache) are NOT rebuilt per call. |
| |
| ``run_with_progress()`` reports per-node progress-bar updates from the |
| ``astream(stream_mode=["updates", "values"])`` event stream and returns the |
| final graph state taken from the last ``"values"`` snapshot. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import AsyncIterator |
| from typing import Any, Callable |
|
|
| from app.async_runtime import AsyncRuntime |
|
|
|
|
def run_async(coro):
    """Synchronously run *coro* on the shared, long-lived background loop.

    Whatever ``AsyncRuntime.submit`` hands back for the coroutine is returned
    unchanged (presumably the coroutine's result — see app/async_runtime.py).
    """
    runtime = AsyncRuntime.get()
    return runtime.submit(coro)
|
|
|
|
def stream_async(async_gen: AsyncIterator[Any]):
    """Expose an async iterator as a plain sync generator.

    Items are pulled through the shared background loop, so the result can be
    handed straight to Streamlit's ``st.write_stream``. Because this function
    is itself a generator, the runtime is looked up only once iteration starts.
    """
    runtime = AsyncRuntime.get()
    yield from runtime.submit_iter(async_gen)
|
|
|
|
| _PROGRESS_LABEL_MAP = { |
| "start_timer": "Starting", |
| "ingest_per_doc": "Loading documents", |
| "ingest_join": "Loading documents (join)", |
| "classify_per_doc": "Classifying", |
| "classify_join": "Classifying (join)", |
| "extract_per_doc": "Extracting structured data", |
| "extract_join": "Extracting (join)", |
| "quote_validator": "Quote verification", |
| "rag_index_per_doc": "Indexing", |
| "rag_join": "Indexing (join)", |
| "compare": "Cross-document checks", |
| "risk": "Risk analysis", |
| "report": "Generating report", |
| "finish_timer": "Done", |
| } |
|
|
|
|
def run_with_progress(
    graph,
    input_state: dict,
    on_progress: Callable[[int, int, str], None] | None = None,
    total_steps: int | None = None,
) -> dict:
    """LangGraph ``astream`` → progress-bar callback + final state.

    The graph is streamed with ``stream_mode=["updates", "values"]`` on the
    shared background event loop. Every node name in an ``"updates"`` event
    advances the step counter by one and triggers ``on_progress``. Every
    ``"values"`` event replaces the remembered state, so the last snapshot is
    what gets returned.

    ``on_progress`` runs on the CALLER thread (Streamlit main thread), so
    ``st.progress(...)`` widgets can be updated from it.

    Args:
        graph: a CompiledStateGraph (or anything whose ``astream`` accepts a
            list ``stream_mode`` and yields ``(mode, chunk)`` pairs).
        input_state: the graph entry state.
        on_progress: optional callback ``(step, total, label)``. ``label`` is
            the entry from ``_PROGRESS_LABEL_MAP``, or the raw node name if it
            has no entry. ``total`` is never smaller than ``step``, so
            ``step / total`` always lies in ``(0, 1]``.
        total_steps: optional progress-bar denominator. If the graph emits
            more node updates than this (presumably the ``*_per_doc`` nodes
            fire once per document), ``total`` grows to match ``step``. When
            omitted, the denominator is ``max(step, 12)``.

    Returns:
        The last ``"values"`` snapshot of the graph state (same as
        ``ainvoke()``), or ``{}`` if the stream emitted none.
    """
    # Denominator used when the caller does not supply total_steps.
    # NOTE(review): presumably a rough count of graph stages — confirm.
    fallback_total = 12

    async def _astream_events():
        """Async generator: re-yield the (stream_mode, event) pairs of a multi-mode astream."""
        async for stream_mode, event in graph.astream(
            input_state, stream_mode=["updates", "values"]
        ):
            yield (stream_mode, event)

    final_state: dict = {}
    step = 0

    events = iter(AsyncRuntime.get().submit_iter(_astream_events()))
    try:
        for stream_mode, event in events:
            if stream_mode == "updates":
                for node_name in (event or {}).keys():
                    step += 1
                    if on_progress is None:
                        continue
                    label = _PROGRESS_LABEL_MAP.get(node_name, node_name)
                    floor = total_steps if total_steps is not None else fallback_total
                    # Clamp so callers never see step > total (st.progress
                    # rejects fractions above 1.0; also avoids a 0 denominator).
                    on_progress(step, max(floor, step), label)
            elif stream_mode == "values":
                if isinstance(event, dict):
                    final_state = event
    finally:
        # If the loop exits early (e.g. the callback raises because Streamlit
        # interrupted the script), close the bridge iterator now rather than at
        # garbage collection. Whether that also cancels the background astream
        # depends on AsyncRuntime.submit_iter — TODO confirm. On normal
        # exhaustion, close() on a finished generator is a no-op.
        close = getattr(events, "close", None)
        if callable(close):
            close()

    return final_state
|
|