paperhawk/graph/states/pipeline_state.py
"""PipelineState — global state of the main pipeline graph.
LangGraph TypedDict (because of the reducers), Pydantic v2 models in the fields
(runtime field validation). Every Send API fan-out / fan-in is collapsed via
the ``merge_doc_results`` and ``merge_risks`` reducers.
Pydantic models with ``dict`` fields (e.g. ``ExtractedData.raw``) are NOT
schema-validated — the JSON-schema-level validation is provided by
``validation/quote_validator.py`` and the runtime checks in
``schemas/pydantic_models.py``.
"""

from __future__ import annotations

from datetime import datetime
from operator import add
from typing import Annotated, TypedDict

from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Pydantic models (used inside the TypedDict fields)
# ---------------------------------------------------------------------------
class PageContent(BaseModel):
    """Content of a single page (PDF/DOCX/PNG ingest output)."""

    page_number: int = 1
    text: str = ""
    is_scanned: bool = False
    """Set by the PDF loader's three-tier fallback: when PyMuPDF native text is
    under 50 characters, ``is_scanned`` becomes True and the loader falls
    through to Tesseract OCR / LLM vision."""
    image_bytes: bytes | None = None
    """Raw image bytes. Set only when ``is_scanned=True`` and extract goes down
    the vision-first path (or when the input is a .png/.jpg, which is
    vision-first by default)."""
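

# A minimal sketch of the three-tier heuristic described in
# ``PageContent.is_scanned``. This is not the real loader (that lives in the
# ingest subgraph); only the 50-character threshold is taken from the
# docstring above.
def _example_is_scanned(native_text: str, threshold: int = 50) -> bool:
    """Return True when the PyMuPDF native text is too short to trust, i.e.
    the page should fall through to Tesseract OCR / LLM vision."""
    return len(native_text.strip()) < threshold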


class IngestedDocument(BaseModel):
    """Output of the ingest_subgraph for a single document."""

    file_name: str
    file_type: str  # pdf | docx | png | jpg | txt
    pages: list[PageContent] = Field(default_factory=list)
    full_text: str = ""
    """Concatenation of all page texts with a ``\\n\\n`` separator. Fed into
    the chunker for RAG."""
    tables_markdown: str = ""
    """Tables extracted by pdfplumber, formatted as Markdown."""
    table_count: int = 0
    is_scanned: bool = False
    """True if at least one page is scanned and structured data can only be
    extracted via the vision path."""


class Classification(BaseModel):
    """Output of the classify_node."""

    doc_type: str
    """invoice | delivery_note | purchase_order | contract | financial_report | other"""
    doc_type_display: str
    """Display label for the UI: 'Invoice', 'Contract', etc."""
    confidence: float = Field(ge=0.0, le=1.0)
    language: str = "en"  # en | hu | de | fr | ...
    used_vision: bool = False
    """True if classification was done via the vision-structured path (scanned doc)."""


class ExtractedData(BaseModel):
    """Output of the extract_subgraph for a single document.

    The ``raw`` dict contains the JSON-schema payload (e.g. invoice.json
    fields). The ``_quotes``, ``_confidence``, ``_source`` aliased fields are
    kept SEPARATELY because they are anti-hallucination layers: domain checks
    read ``raw`` (typed names), but chat tools return the full ExtractedData
    JSON.
    """

    raw: dict = Field(default_factory=dict)
    quotes: list[str] = Field(default_factory=list, alias="_quotes")
    confidence: dict = Field(default_factory=dict, alias="_confidence")
    source: dict = Field(default_factory=dict, alias="_source")

    model_config = {"populate_by_name": True}
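

# A minimal sketch of the alias behaviour above: ``populate_by_name`` lets
# callers construct with either the Python field name or the underscore alias,
# and ``model_dump(by_alias=True)`` restores the ``_quotes``-style keys that
# the chat tools return. The values here are illustrative only.
def _example_extracted_data_aliases() -> dict:
    by_field = ExtractedData(raw={"invoice_number": "INV-001"}, quotes=["p. 1"])
    by_alias = ExtractedData.model_validate(
        {"raw": {"invoice_number": "INV-001"}, "_quotes": ["p. 1"]}
    )
    assert by_field.quotes == by_alias.quotes
    return by_field.model_dump(by_alias=True)  # keys come back as ``_quotes`` etc.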


class Risk(BaseModel):
    """A single risk / finding — every risk source uses this unified format."""

    description: str
    severity: str  # high | medium | low | info
    rationale: str = ""
    kind: str  # validation | domain_rule | plausibility | llm_analysis | cross_check
    regulation: str | None = None
    affected_document: str | None = None
    source_check_id: str | None = None
    """For domain-check risks: which check generated this (debug + filtering)."""


class ProcessedDocument(BaseModel):
    """End-to-end result for a single document: ingest + classify + extract + risks."""

    ingested: IngestedDocument
    classification: Classification | None = None
    extracted: ExtractedData | None = None
    risks: list[Risk] = Field(default_factory=list)
    """Document-level risks (NOT routed into the global ``state["risks"]`` —
    that one is centrally aggregated)."""
    rag_chunks_indexed: int = 0
    processing_seconds: float = 0.0


class ComparisonReport(BaseModel):
    """Output of three-way matching (compare_node).

    The ``matches`` items are dict-shaped MatchResult records:
    ``{status, severity, message, field_name, expected, actual, source_a, source_b}``.
    """

    invoice_filename: str | None = None
    delivery_note_filename: str | None = None
    purchase_order_filename: str | None = None
    matches: list[dict] = Field(default_factory=list)

    # Aggregated counters
    total_checks: int = 0
    ok_count: int = 0
    warning_count: int = 0
    critical_count: int = 0
    missing_count: int = 0
    overall_status: str = "ok"  # ok | warning | critical | missing
    summary: str = ""
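

# Illustrative only: one plausible ``matches`` record in the dict shape named
# by the ComparisonReport docstring (all field values are made up).
_EXAMPLE_MATCH_RESULT: dict = {
    "status": "warning",
    "severity": "medium",
    "message": "Invoice total differs from purchase order total",
    "field_name": "total_amount",
    "expected": 1200.0,
    "actual": 1180.0,
    "source_a": "invoice.pdf",
    "source_b": "purchase_order.pdf",
}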


# Forward stubs for the Phase 6 models
class DDPortfolioReport(BaseModel):
    """Forward stub for the Phase 6 DD assistant output."""

    contract_count: int = 0
    contracts: list[dict] = Field(default_factory=list)
    total_monthly_obligations: dict[str, float] = Field(default_factory=dict)
    expiring_soon: list[str] = Field(default_factory=list)
    high_risk_contracts: list[str] = Field(default_factory=list)
    top_red_flags: list[str] = Field(default_factory=list)
    executive_summary: str = ""
    specialist_outputs: dict = Field(default_factory=dict)


class PackageInsights(BaseModel):
    """Forward stub for the Phase 6 package insights output."""

    executive_summary: str = ""
    findings: list[dict] = Field(default_factory=list)
    key_observations: list[str] = Field(default_factory=list)
    package_type: str = "general"


# ---------------------------------------------------------------------------
# Reducers for the Send API fan-in
# ---------------------------------------------------------------------------
def merge_doc_results(
    left: list[ProcessedDocument],
    right: list[ProcessedDocument],
) -> list[ProcessedDocument]:
    """Send fan-in: FIELD-LEVEL merge keyed by ``file_name``.

    When different per-doc Send fan-out nodes update separate fields of the
    same document (e.g. classify_per_doc → classification, rag_index_per_doc →
    rag_chunks_indexed), the reducer does NOT clobber already-set fields —
    only not-None new values overwrite.

    The reducer is ASSOCIATIVE and PURE.
    """
    by_name: dict[str, ProcessedDocument] = {
        d.ingested.file_name: d for d in left if d.ingested
    }
    for d in right:
        if d.ingested is None:
            continue
        existing = by_name.get(d.ingested.file_name)
        if existing is None:
            by_name[d.ingested.file_name] = d
            continue
        # Field-level merge: only NOT-NONE new values overwrite
        update: dict = {}
        if d.classification is not None:
            update["classification"] = d.classification
        if d.extracted is not None:
            update["extracted"] = d.extracted
        if d.risks:
            update["risks"] = d.risks
        if d.rag_chunks_indexed:
            update["rag_chunks_indexed"] = d.rag_chunks_indexed
        if d.processing_seconds:
            update["processing_seconds"] = d.processing_seconds
        if update:
            by_name[d.ingested.file_name] = existing.model_copy(update=update)
    return list(by_name.values())
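

# Illustrative only (file name and values made up): two Send branches update
# different fields of the same document, and merge_doc_results combines them
# field-by-field instead of letting the later write clobber the earlier one.
def _example_merge_doc_results() -> ProcessedDocument:
    doc = IngestedDocument(file_name="invoice.pdf", file_type="pdf")
    from_classify = ProcessedDocument(
        ingested=doc,
        classification=Classification(
            doc_type="invoice", doc_type_display="Invoice", confidence=0.95
        ),
    )
    from_rag_index = ProcessedDocument(ingested=doc, rag_chunks_indexed=12)
    merged = merge_doc_results([from_classify], [from_rag_index])[0]
    assert merged.classification is not None and merged.rag_chunks_indexed == 12
    return merged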


def merge_risks(left: list[Risk], right: list[Risk]) -> list[Risk]:
    """Risk dedup keyed by description (mirrors the prototype-agentic _add_risk).

    First occurrence wins (left order preserved). A risk counts as a duplicate
    iff the exact same description string has been seen before — common
    because comparison risks are document-independent and a per-doc loop would
    re-add them on every iteration.
    """
    seen = {r.description for r in left}
    out = list(left)
    for r in right:
        if r.description not in seen:
            out.append(r)
            seen.add(r.description)
    return out
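

# Illustrative only: the same description arriving from two branches is kept
# once, and the left-hand (first-seen) ordering is preserved.
def _example_merge_risks() -> list[Risk]:
    dup = Risk(description="Totals do not reconcile", severity="high", kind="cross_check")
    new = Risk(description="Missing VAT number", severity="medium", kind="validation")
    merged = merge_risks([dup], [dup, new])
    assert [r.description for r in merged] == [dup.description, new.description]
    return merged

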
# ---------------------------------------------------------------------------
# PipelineState TypedDict — the full graph state
# ---------------------------------------------------------------------------
class PipelineState(TypedDict, total=False):
    """The main pipeline graph state. Every node reads/updates this.

    ``total=False`` means every key is optional (not all of them are
    initialized at START). Send API fan-out branches write back into
    ``documents`` and ``risks`` via the reducers above.
    """

    # Input
    files: list[tuple[str, bytes]]
    """[(file_name, file_bytes), ...] — fed in from the Streamlit upload."""

    # Per-doc fan-out / fan-in (with reducers)
    documents: Annotated[list[ProcessedDocument], merge_doc_results]
    risks: Annotated[list[Risk], merge_risks]

    # Aggregated outputs
    comparison: ComparisonReport | None
    report: dict
    package_insights: PackageInsights | None
    dd_report: DDPortfolioReport | None

    # Timing / progress
    started_at: datetime
    finished_at: datetime
    processing_seconds: float
    progress_events: Annotated[list[str], add]
    """Each node tick appends a string (Streamlit progress bar feed)."""