File size: 9,272 Bytes
"""PipelineState — global state of the main pipeline graph.
LangGraph TypedDict (because of the reducers), Pydantic v2 models in the fields
(runtime field validation). Every Send API fan-out / fan-in is collapsed via
the ``merge_doc_results`` and ``merge_risks`` reducers.
Pydantic models with ``dict`` fields (e.g. ``ExtractedData.raw``) are NOT
schema-validated β the JSON-schema-level validation is provided by
``validation/quote_validator.py`` and the runtime checks in
``schemas/pydantic_models.py``.
"""
from __future__ import annotations
from datetime import datetime
from operator import add
from typing import Annotated, TypedDict
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Pydantic models (used inside the TypedDict fields)
# ---------------------------------------------------------------------------
class PageContent(BaseModel):
    """A single page of an ingested file (ingest output for PDF/DOCX/PNG input)."""

    # Page number within the source file (defaults to 1).
    page_number: int = 1
    # Text recovered for this page.
    text: str = ""
    # Set by the PDF loader's three-tier fallback: when PyMuPDF's native text is
    # shorter than 50 characters the page is flagged as scanned, and processing
    # falls through to Tesseract OCR / LLM vision.
    is_scanned: bool = False
    # Raw image bytes. Only populated when ``is_scanned`` is True and extraction
    # takes the vision-first path, or when the input is a .png/.jpg (those are
    # vision-first by default).
    image_bytes: bytes | None = None
class IngestedDocument(BaseModel):
    """Result of running the ingest_subgraph over one uploaded document."""

    file_name: str
    file_type: str  # one of: pdf | docx | png | jpg | txt
    pages: list[PageContent] = Field(default_factory=list)
    # Every page's text joined with a "\n\n" separator; this string is what
    # gets handed to the RAG chunker.
    full_text: str = ""
    # Tables pulled out by pdfplumber, rendered as Markdown.
    tables_markdown: str = ""
    table_count: int = 0
    # True when at least one page is scanned, in which case structured data
    # can only be extracted through the vision path.
    is_scanned: bool = False
class Classification(BaseModel):
    """Result produced by classify_node for a single document."""

    # Machine-readable document type:
    # invoice | delivery_note | purchase_order | contract | financial_report | other
    doc_type: str
    # Human-facing label for the UI, e.g. 'Invoice' or 'Contract'.
    doc_type_display: str
    # Classifier confidence; validation rejects values outside [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0)
    language: str = "en"  # language code: en | hu | de | fr | ...
    # True when classification went through the vision-structured path
    # (used for scanned documents).
    used_vision: bool = False
class ExtractedData(BaseModel):
    """Result of the extract_subgraph for a single document.

    ``raw`` carries the JSON-schema payload (for example the invoice.json
    fields). The anti-hallucination layers -- quotes, confidence and source
    information -- are deliberately kept in their own fields rather than
    inside ``raw``: domain checks read the typed names from ``raw``, while
    chat tools return the complete ExtractedData JSON.

    Each layer field is exposed under an underscore-prefixed alias
    (``_quotes``, ``_confidence``, ``_source``). Because ``populate_by_name``
    is enabled, input may use either the alias or the Python field name.
    """

    model_config = {"populate_by_name": True}

    raw: dict = Field(default_factory=dict)
    quotes: list[str] = Field(alias="_quotes", default_factory=list)
    confidence: dict = Field(alias="_confidence", default_factory=dict)
    source: dict = Field(alias="_source", default_factory=dict)
class Risk(BaseModel):
    """A single risk or finding — the one shape every risk source emits."""

    description: str
    severity: str  # high | medium | low | info
    rationale: str = ""
    kind: str  # validation | domain_rule | plausibility | llm_analysis | cross_check
    regulation: str | None = None
    affected_document: str | None = None
    # For risks raised by a domain check: the id of the check that produced
    # it (useful for debugging and filtering).
    source_check_id: str | None = None
class ProcessedDocument(BaseModel):
    """End-to-end record for one document: ingest, classify, extract and risks."""

    ingested: IngestedDocument
    classification: Classification | None = None
    extracted: ExtractedData | None = None
    # Risks scoped to this document only. They are NOT routed into the global
    # state['risks'] list, which is aggregated centrally.
    risks: list[Risk] = Field(default_factory=list)
    rag_chunks_indexed: int = 0
    processing_seconds: float = 0.0
class ComparisonReport(BaseModel):
    """Three-way matching result produced by compare_node.

    ``matches`` holds MatchResult records stored as plain dicts shaped like
    ``{status, severity, message, field_name, expected, actual, source_a,
    source_b}``.
    """

    # File names of the documents involved in the match.
    invoice_filename: str | None = None
    delivery_note_filename: str | None = None
    purchase_order_filename: str | None = None
    matches: list[dict] = Field(default_factory=list)

    # --- Aggregated counters ---
    total_checks: int = 0
    ok_count: int = 0
    warning_count: int = 0
    critical_count: int = 0
    missing_count: int = 0
    overall_status: str = "ok"  # one of: ok | warning | critical | missing
    summary: str = ""
# Forward-references for Phase 6 models
class DDPortfolioReport(BaseModel):
    """Placeholder model for the Phase 6 DD assistant output."""

    contract_count: int = 0
    contracts: list[dict] = Field(default_factory=list)
    total_monthly_obligations: dict[str, float] = Field(default_factory=dict)
    expiring_soon: list[str] = Field(default_factory=list)
    high_risk_contracts: list[str] = Field(default_factory=list)
    top_red_flags: list[str] = Field(default_factory=list)
    executive_summary: str = ""
    specialist_outputs: dict = Field(default_factory=dict)
class PackageInsights(BaseModel):
    """Placeholder model for the Phase 6 package-insights output."""

    executive_summary: str = ""
    findings: list[dict] = Field(default_factory=list)
    key_observations: list[str] = Field(default_factory=list)
    package_type: str = "general"  # defaults to "general"
# ---------------------------------------------------------------------------
# Reducers for the Send API fan-in
# ---------------------------------------------------------------------------
def _as_doc_list(docs) -> list[ProcessedDocument]:
    """Normalize a reducer operand: None -> [], a bare document -> [doc]."""
    if docs is None:
        return []
    # A single pydantic document is itself iterable (it yields (field, value)
    # pairs), so it must be detected before falling back to list().
    if hasattr(docs, "ingested"):
        return [docs]
    return list(docs)


def _merge_doc_fields(
    existing: ProcessedDocument,
    incoming: ProcessedDocument,
) -> ProcessedDocument:
    """Overlay the *set* fields of ``incoming`` onto ``existing``.

    ``ingested`` is never replaced because it provides the merge key.
    ``classification`` and ``extracted`` count as set when not None;
    ``risks``, ``rag_chunks_indexed`` and ``processing_seconds`` count as set
    when truthy. A set ``risks`` list replaces the stored one as a whole
    rather than being concatenated.
    """
    update: dict = {}
    if incoming.classification is not None:
        update["classification"] = incoming.classification
    if incoming.extracted is not None:
        update["extracted"] = incoming.extracted
    if incoming.risks:
        update["risks"] = incoming.risks
    if incoming.rag_chunks_indexed:
        update["rag_chunks_indexed"] = incoming.rag_chunks_indexed
    if incoming.processing_seconds:
        update["processing_seconds"] = incoming.processing_seconds
    # model_copy builds a new instance, so neither input is mutated.
    return existing.model_copy(update=update) if update else existing


def merge_doc_results(
    left: list[ProcessedDocument],
    right: list[ProcessedDocument],
) -> list[ProcessedDocument]:
    """Send fan-in reducer: FIELD-LEVEL merge of documents keyed by file name.

    Different per-doc Send fan-out nodes may update separate fields of the
    same document (e.g. classify_per_doc -> ``classification``,
    rag_index_per_doc -> ``rag_chunks_indexed``). Already-set fields are NOT
    clobbered: only fields that are set on the incoming document overwrite
    the stored copy (not-None for ``classification``/``extracted``, truthy
    for ``risks``/``rag_chunks_indexed``/``processing_seconds``).

    Both operands go through the same merge, so duplicate file names inside
    ``left`` are also combined field by field instead of last-one-wins.
    Either operand may be ``None`` (treated as empty) or a single document.
    Documents whose ``ingested`` is None are skipped.

    The reducer is PURE (inputs are never mutated; updates go through
    ``model_copy``) and ASSOCIATIVE: per field the last set value wins, and
    each file keeps the list position of its first occurrence.

    Args:
        left: Current channel value.
        right: Update written by a node.

    Returns:
        A new list with one document per file name.
    """
    by_name: dict[str, ProcessedDocument] = {}
    for doc in (*_as_doc_list(left), *_as_doc_list(right)):
        if doc.ingested is None:
            continue
        name = doc.ingested.file_name
        existing = by_name.get(name)
        # Re-assigning an existing key keeps its original insertion position.
        by_name[name] = doc if existing is None else _merge_doc_fields(existing, doc)
    return list(by_name.values())
def merge_risks(left: list[Risk], right: list[Risk]) -> list[Risk]:
    """Append risks from ``right`` whose description is not already present.

    Mirrors the prototype-agentic ``_add_risk``. Two risks count as duplicates
    only when their ``description`` strings are exactly equal. Duplicates are
    common because comparison risks are document-independent, so a per-doc
    loop would otherwise re-add them on every iteration. ``left`` is copied
    unchanged and in order; from ``right`` the first occurrence of each new
    description wins.
    """
    merged = list(left)
    known = set(risk.description for risk in merged)
    for candidate in right:
        key = candidate.description
        if key in known:
            continue
        known.add(key)
        merged.append(candidate)
    return merged
# ---------------------------------------------------------------------------
# PipelineState TypedDict — the full graph state
# ---------------------------------------------------------------------------
class PipelineState(TypedDict, total=False):
    """State shared by every node of the main pipeline graph.

    Declared with ``total=False`` so every key is optional: not all of them
    are populated at START. The Send API fan-out branches write their results
    back through the ``merge_doc_results`` / ``merge_risks`` reducers attached
    to ``documents`` and ``risks``. ``progress_events`` uses ``operator.add``
    as its reducer.
    """

    # --- Input ---------------------------------------------------------------
    # Uploaded files as (file_name, file_bytes) pairs from the Streamlit upload.
    files: list[tuple[str, bytes]]

    # --- Per-document fan-out / fan-in (reducer-backed) ----------------------
    documents: Annotated[list[ProcessedDocument], merge_doc_results]
    risks: Annotated[list[Risk], merge_risks]

    # --- Aggregated outputs --------------------------------------------------
    comparison: ComparisonReport | None
    report: dict
    package_insights: PackageInsights | None
    dd_report: DDPortfolioReport | None

    # --- Timing / progress ---------------------------------------------------
    started_at: datetime
    finished_at: datetime
    processing_seconds: float
    # Each node tick appends a string here (feeds the Streamlit progress bar).
    progress_events: Annotated[list[str], add]
|