paperhawk / app /streaming.py
Nándorfi Vince
Initial paperhawk push to HF Space (LFS for binaries)
7ff7119
"""Streamlit + asyncio integration helper.
Bridges Streamlit (uvloop) and LangGraph (asyncio) via a long-lived background
event loop (see app/async_runtime.py).
``run_async()`` and ``stream_async()`` are simple wrappers — every call uses
the same background loop, so persistent resources (ChromaDB, AsyncSqliteSaver,
sentence-transformers cache) are NOT rebuilt per call.
``run_with_progress()`` produces per-event progress-bar updates from the
``astream(stream_mode="updates")`` event stream.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any, Callable
from app.async_runtime import AsyncRuntime
def run_async(coro):
    """Submit ``coro`` to the shared background event loop from sync code.

    Thin facade over ``AsyncRuntime.get().submit``: the process-wide runtime
    is looked up on every call and whatever ``submit`` returns is handed
    straight back to the caller.
    """
    runtime = AsyncRuntime.get()
    return runtime.submit(coro)
def stream_async(async_gen: AsyncIterator[Any]):
    """Expose an async iterator as a plain synchronous generator.

    The items come from ``AsyncRuntime.get().submit_iter``; this generator
    simply delegates to it, which makes the result usable wherever a sync
    iterable is expected (e.g. Streamlit's ``st.write_stream``).

    The runtime lookup happens lazily, on the first ``next()`` call, because
    this function is itself a generator.
    """
    sync_items = AsyncRuntime.get().submit_iter(async_gen)
    # ``yield from`` (rather than a manual loop) keeps close()/throw()
    # propagation to the underlying iterator intact.
    yield from sync_items
# Graph node name -> human-readable progress label.
# Keys are matched against the node names that appear in ``updates`` stream
# events; nodes without an entry are displayed under their raw name.
_PROGRESS_LABEL_MAP: dict[str, str] = {
    "start_timer": "Starting",
    # Ingestion
    "ingest_per_doc": "Loading documents",
    "ingest_join": "Loading documents (join)",
    # Classification
    "classify_per_doc": "Classifying",
    "classify_join": "Classifying (join)",
    # Extraction
    "extract_per_doc": "Extracting structured data",
    "extract_join": "Extracting (join)",
    "quote_validator": "Quote verification",
    # Indexing
    "rag_index_per_doc": "Indexing",
    "rag_join": "Indexing (join)",
    # Analysis and reporting
    "compare": "Cross-document checks",
    "risk": "Risk analysis",
    "report": "Generating report",
    "finish_timer": "Done",
}
def run_with_progress(
    graph,
    input_state: dict,
    on_progress: Callable[[int, int, str], None] | None = None,
    total_steps: int | None = None,
) -> dict:
    """LangGraph ``astream`` → progress-bar callback + final state.

    The background event loop drives the async generator, while the events
    are consumed by a plain ``for`` loop in this function. ``on_progress`` is
    therefore invoked on the CALLER thread (the Streamlit main thread), once
    per node update — so ``st.progress(...)`` widgets can be updated safely.

    Args:
        graph: a CompiledStateGraph (or anything whose ``astream`` accepts a
            list ``stream_mode`` and yields ``(stream_mode, event)`` pairs).
        input_state: the graph entry state.
        on_progress: optional callback ``(step, total, label)``. ``label`` is
            the friendly name from ``_PROGRESS_LABEL_MAP`` or, for unmapped
            nodes, the raw node name. ``total`` is never smaller than
            ``step``, so ``step / total`` always stays within ``(0, 1]``.
        total_steps: optional progress-bar denominator (an estimate of how
            many node updates the run will emit). Defaults to 12. If the run
            emits more updates than estimated, the reported total grows with
            ``step`` instead of letting the ratio exceed 100 %.

    Returns:
        The last ``values`` snapshot emitted by the stream, i.e. the graph's
        final state (same as ``ainvoke()``); an empty dict if the stream
        produced no ``values`` event.
    """

    async def _tagged_events():
        """Re-yield the ``(stream_mode, event)`` pairs of the multi-mode stream."""
        async for mode, payload in graph.astream(
            input_state, stream_mode=["updates", "values"]
        ):
            yield (mode, payload)

    # Denominator estimate: the caller's value, or the built-in default of 12.
    expected_total = 12 if total_steps is None else total_steps

    final_state: dict = {}
    step = 0
    # ``submit_iter`` turns the async iterator into a sync one consumed here,
    # so the progress callback below runs on the calling thread.
    for mode, payload in AsyncRuntime.get().submit_iter(_tagged_events()):
        if mode == "updates":
            # One ``updates`` event maps node name -> that node's state delta.
            for node_name in payload or {}:
                step += 1
                if on_progress is None:
                    continue
                label = _PROGRESS_LABEL_MAP.get(node_name, node_name)
                # Clamp so the denominator can never fall below the step
                # count (or be zero), whether or not ``total_steps`` was given.
                on_progress(step, max(step, expected_total), label)
        elif mode == "values" and isinstance(payload, dict):
            # Each ``values`` event is a full state snapshot; keep the latest.
            final_state = payload
    return final_state