# paperhawk/nodes/pipeline/classify_node.py
# Author: Nándorfi Vince
# Commit 7ff7119 — "Initial paperhawk push to HF Space (LFS for binaries)"
"""classify_node — LLM-based classification for a single document.
Async node: input is a DocState-shaped dict (from the dispatch_classify Send),
output is ``{"documents": [pd_with_classification]}`` which the parent reducer
(merge_doc_results) merges into the matching ProcessedDocument.
Vision-aware: if the ingested document has ``is_scanned=True`` and
``image_bytes``, classification runs on the vision path (image-based LLM call).
Otherwise text-based.
Dummy mode: when ``settings.is_dummy`` we do NOT call the LLM — keyword
heuristics return a Classification (fast + reproducible, eval-friendly).
vLLM/Ollama mode: factory ``build_classify_node(llm)`` captures the LLM
Runnable in a closure and calls ``with_structured_output(Classification)``.
Vision-aware: for scanned docs we use the multimodal
``HumanMessage(content=[{type=image,...}, {type=text,...}])`` shape.
"""
from __future__ import annotations
import base64
import re
from langchain_core.messages import HumanMessage, SystemMessage
from config import settings
from graph.states.pipeline_state import (
Classification,
IngestedDocument,
ProcessedDocument,
)
# 6 doc_type categories + display label.
# The keys define the closed set of valid classification codes; the values are
# the human-facing labels. build_classify_node uses membership in this mapping
# to normalize unknown labels the LLM may invent.
_DOC_TYPE_DISPLAY = {
    "invoice": "Invoice",
    "delivery_note": "Delivery Note",
    "purchase_order": "Purchase Order",
    "contract": "Contract",
    "financial_report": "Financial Report",
    "other": "Other",
}
# Keyword heuristic for dummy mode (multilingual, with word-boundary tolerance).
# Order MATTERS — delivery_note must be checked before invoice (so "delivery
# note" doesn't accidentally match the invoice keyword in some texts).
# Patterns mix English, Hungarian (accent-stripped forms such as
# "szallitolev…") and German keywords; all compile with re.I.
_KEYWORD_RULES: list[tuple[str, re.Pattern[str]]] = [
    ("delivery_note", re.compile(
        r"\b(delivery\s*note|shipping\s*note|szallitolev\w*|Lieferschein)", re.I)),
    ("purchase_order", re.compile(
        r"\b(purchase\s*order|order\s*number|order\s*confirmation|"
        r"megrendel\w*|Bestellung)", re.I)),
    ("contract", re.compile(
        r"\b(contract|agreement|service\s*agreement|nda|"
        r"non[-\s]?disclosure|szerzodes|szerzodest|titoktart\w*|"
        r"kotber\w*|felmondas\w*|Vertrag)", re.I)),
    ("financial_report", re.compile(
        r"\b(income\s*statement|profit.{0,5}loss|p&l|balance\s*sheet|"
        r"cash\s*flow|financial\s*statement|"
        r"eredmenykimut\w*|merleg|penzugyi|Bilanz|Gewinn-?\s*und\s*Verlustrechnung)", re.I)),
    ("invoice", re.compile(r"\b(invoice|tax\s*invoice|szamla\w*|sz\.szam|Rechnung)", re.I)),
]
# Simplified language detection (EN/HU/DE)
_LANG_INDICATORS = {
"en": re.compile(r"\b(the|and|or|of|is|invoice|contract|agreement)\b", re.I),
"hu": re.compile(r"\b(es|az|hogy|nem|van|szamla|szerzodes)\b", re.I),
"de": re.compile(r"\b(der|die|das|und|ist|rechnung|vertrag)\b", re.I),
}
def _detect_language(text: str) -> str:
"""Simple keyword-ratio language detection (default: en)."""
if not text:
return "en"
snippet = text[:5000].lower()
scores = {lang: len(pat.findall(snippet)) for lang, pat in _LANG_INDICATORS.items()}
best = max(scores.items(), key=lambda x: x[1])
return best[0] if best[1] >= 3 else "en"
def _classify_dummy(ingested: IngestedDocument) -> Classification:
    """Dummy classifier — keyword heuristics only, no LLM call (< 1 ms).

    Tries the file name first (often the strongest hint), then the extracted
    text, finally falls back to "other". The confidence value encodes which
    signal won: 0.85 file name, 0.7 text, 0.5 fallback.

    Args:
        ingested: The parsed document (text, file name, scanned flag).

    Returns:
        A Classification; ``used_vision`` mirrors ``ingested.is_scanned``.
    """
    text = ingested.full_text or ""
    # Normalize separators so e.g. "delivery_note.pdf" matches "delivery note".
    file_name = ingested.file_name.replace("_", " ").replace("-", " ")

    def _make(doc_type: str, confidence: float) -> Classification:
        # Single construction point: keeps the three result branches in sync.
        return Classification(
            doc_type=doc_type,
            doc_type_display=_DOC_TYPE_DISPLAY[doc_type],
            confidence=confidence,
            language=_detect_language(text),
            used_vision=ingested.is_scanned,
        )

    # File-name-based override (often the strongest hint).
    for doc_type, pattern in _KEYWORD_RULES:
        if pattern.search(file_name):
            return _make(doc_type, 0.85)
    # Text-based match.
    for doc_type, pattern in _KEYWORD_RULES:
        if pattern.search(text):
            return _make(doc_type, 0.7)
    # Nothing matched: catch-all category.
    return _make("other", 0.5)
# ---------------------------------------------------------------------------
# vLLM/Ollama LLM classification
# ---------------------------------------------------------------------------
# System prompt shared by the text and vision paths. The field names it
# mentions (doc_type, doc_type_display, confidence, language, used_vision)
# must stay in sync with the Classification schema and _DOC_TYPE_DISPLAY.
_CLASSIFY_SYSTEM_PROMPT = """You are a document classifier. Categorize the uploaded document into ONE of:
invoice, delivery_note, purchase_order, contract, financial_report, other.
Work only from the document content; do not fabricate. Fill ``doc_type`` with the code
('invoice', 'delivery_note', 'purchase_order', 'contract', 'financial_report', 'other'),
and ``doc_type_display`` with the display label ('Invoice', 'Delivery Note',
'Purchase Order', 'Contract', 'Financial Report', 'Other'). ``confidence`` is a
float between 0.0 and 1.0. ``language`` is the document language ('en', 'hu', 'de'),
default 'en'. ``used_vision`` is always False (the system fills it in)."""
async def _classify_llm_text(
    structured_llm, ingested: IngestedDocument
) -> Classification:
    """Text-based LLM classification (with_structured_output).

    Args:
        structured_llm: Runnable already bound to the Classification schema.
        ingested: Parsed document; only the first 3000 chars of text are sent.

    Returns:
        A Classification with ``used_vision`` forced to False.
    """
    text_preview = (ingested.full_text or "")[:3000]
    user_prompt = f"Classify the following document by type:\n\n{text_preview}"
    response = await structured_llm.ainvoke([
        SystemMessage(content=_CLASSIFY_SYSTEM_PROMPT),
        HumanMessage(content=user_prompt),
    ])
    if isinstance(response, Classification):
        response.used_vision = False
        return response
    # Some providers return a plain dict/BaseModel instead of the target
    # class. Coerce it, and force used_vision=False here too — the previous
    # fallback left whatever value the LLM produced, unlike the vision path
    # which forces used_vision=True in its equivalent branch.
    obj = response.model_dump() if hasattr(response, "model_dump") else dict(response)
    obj["used_vision"] = False
    return Classification(**obj)
async def _classify_llm_vision(
    structured_llm, ingested: IngestedDocument
) -> Classification:
    """Vision-based LLM classification — sends the first page image.

    Degrades to the text path when no rendered page image is available.
    The returned Classification always has ``used_vision=True`` on the
    vision path.
    """
    pages = ingested.pages
    if not pages or not pages[0].image_bytes:
        # No image to look at — fall back to text-based classification.
        return await _classify_llm_text(structured_llm, ingested)

    encoded = base64.standard_b64encode(pages[0].image_bytes).decode("ascii")
    multimodal = HumanMessage(content=[
        {"type": "text", "text": "What kind of business document is shown in this image? Classify it."},
        {
            "type": "image",
            "source_type": "base64",
            "data": encoded,
            "mime_type": "image/png",
        },
    ])
    response = await structured_llm.ainvoke(
        [SystemMessage(content=_CLASSIFY_SYSTEM_PROMPT), multimodal]
    )
    if isinstance(response, Classification):
        response.used_vision = True
        return response
    # Provider returned a dict/BaseModel — coerce and mark the vision path.
    payload = response.model_dump() if hasattr(response, "model_dump") else dict(response)
    payload["used_vision"] = True
    return Classification(**payload)
def build_classify_node(llm=None):
    """Factory: per-doc classify node.

    Args:
        llm: A BaseChatModel-like Runnable (vLLM/Ollama/Dummy). If None or
            dummy mode, the regex-based heuristic runs.

    Returns:
        An async node ``state -> {"documents": [ProcessedDocument]}`` suitable
        for the parent graph's merge reducer.
    """
    # Bind the structured-output schema once, at build time, not per call.
    structured_llm = None
    if llm is not None and not settings.is_dummy:
        structured_llm = llm.with_structured_output(Classification)

    async def classify_node(state: dict) -> dict:
        ingested: IngestedDocument | None = state.get("ingested")
        if ingested is None:
            # Nothing to classify — emit no state update.
            return {}

        if settings.is_dummy or structured_llm is None:
            classification = _classify_dummy(ingested)
        else:
            try:
                classify = (
                    _classify_llm_vision if ingested.is_scanned else _classify_llm_text
                )
                classification = await classify(structured_llm, ingested)
                # Normalize labels the LLM may have invented outside the
                # known category set.
                if classification.doc_type not in _DOC_TYPE_DISPLAY:
                    classification.doc_type = "other"
                if classification.doc_type_display not in _DOC_TYPE_DISPLAY.values():
                    classification.doc_type_display = _DOC_TYPE_DISPLAY[classification.doc_type]
            except Exception:
                # LLM error (rate limit, network, schema fail) — fallback to dummy
                classification = _classify_dummy(ingested)

        pd = ProcessedDocument(ingested=ingested, classification=classification)
        return {"documents": [pd]}

    return classify_node
# Legacy backward-compat name (dummy mode) — works without the build factory.
# Build the node ONCE at import time: the previous implementation re-ran
# build_classify_node(None) on every invocation. With llm=None the built node
# holds no LLM state and re-reads settings.is_dummy per call, so hoisting the
# construction is behavior-preserving.
_LEGACY_CLASSIFY_NODE = build_classify_node(None)


async def classify_node(state: dict) -> dict:
    """Legacy signature (dummy mode): equivalent to build_classify_node(None)()."""
    return await _LEGACY_CLASSIFY_NODE(state)