"""classify_node — LLM-based classification for a single document.
Async node: input is a DocState-shaped dict (from the dispatch_classify Send),
output is ``{"documents": [pd_with_classification]}`` which the parent reducer
(merge_doc_results) merges into the matching ProcessedDocument.
Vision-aware: if the ingested document has ``is_scanned=True`` and
``image_bytes``, classification runs on the vision path (image-based LLM call).
Otherwise text-based.
Dummy mode: when ``settings.is_dummy`` we do NOT call the LLM — keyword
heuristics return a Classification (fast + reproducible, eval-friendly).
vLLM/Ollama mode: factory ``build_classify_node(llm)`` captures the LLM
Runnable in a closure and calls ``with_structured_output(Classification)``.
Vision-aware: for scanned docs we use the multimodal
``HumanMessage(content=[{type=image,...}, {type=text,...}])`` shape.
"""
from __future__ import annotations
import base64
import re
from langchain_core.messages import HumanMessage, SystemMessage
from config import settings
from graph.states.pipeline_state import (
Classification,
IngestedDocument,
ProcessedDocument,
)
# The 6 supported doc_type codes mapped to their human-readable display
# labels. Keys are the canonical codes the classifier emits in ``doc_type``;
# values populate ``doc_type_display``.
_DOC_TYPE_DISPLAY = {
    "invoice": "Invoice",
    "delivery_note": "Delivery Note",
    "purchase_order": "Purchase Order",
    "contract": "Contract",
    "financial_report": "Financial Report",
    "other": "Other",
}
# Keyword heuristic for dummy mode (multilingual EN/HU/DE, with word-boundary
# tolerance).
# Order MATTERS — delivery_note must be checked before invoice (so "delivery
# note" doesn't accidentally match the invoice keyword in some texts).
# NOTE(review): the Hungarian keywords are spelled without accents —
# presumably the ingested text is accent-stripped or OCR drops diacritics;
# confirm against the ingestion path.
_KEYWORD_RULES: list[tuple[str, re.Pattern[str]]] = [
    ("delivery_note", re.compile(
        r"\b(delivery\s*note|shipping\s*note|szallitolev\w*|Lieferschein)", re.I)),
    ("purchase_order", re.compile(
        r"\b(purchase\s*order|order\s*number|order\s*confirmation|"
        r"megrendel\w*|Bestellung)", re.I)),
    ("contract", re.compile(
        r"\b(contract|agreement|service\s*agreement|nda|"
        r"non[-\s]?disclosure|szerzodes|szerzodest|titoktart\w*|"
        r"kotber\w*|felmondas\w*|Vertrag)", re.I)),
    ("financial_report", re.compile(
        r"\b(income\s*statement|profit.{0,5}loss|p&l|balance\s*sheet|"
        r"cash\s*flow|financial\s*statement|"
        r"eredmenykimut\w*|merleg|penzugyi|Bilanz|Gewinn-?\s*und\s*Verlustrechnung)", re.I)),
    ("invoice", re.compile(r"\b(invoice|tax\s*invoice|szamla\w*|sz\.szam|Rechnung)", re.I)),
]
# Simplified language detection (EN/HU/DE)
_LANG_INDICATORS = {
"en": re.compile(r"\b(the|and|or|of|is|invoice|contract|agreement)\b", re.I),
"hu": re.compile(r"\b(es|az|hogy|nem|van|szamla|szerzodes)\b", re.I),
"de": re.compile(r"\b(der|die|das|und|ist|rechnung|vertrag)\b", re.I),
}
def _detect_language(text: str) -> str:
"""Simple keyword-ratio language detection (default: en)."""
if not text:
return "en"
snippet = text[:5000].lower()
scores = {lang: len(pat.findall(snippet)) for lang, pat in _LANG_INDICATORS.items()}
best = max(scores.items(), key=lambda x: x[1])
return best[0] if best[1] >= 3 else "en"
def _classify_dummy(ingested: IngestedDocument) -> Classification:
    """Dummy classifier — keyword-based, < 1 ms, no LLM call.

    Tries the file name first (often the strongest hint, confidence 0.85),
    then the extracted text (confidence 0.7), and falls back to ``other``
    (confidence 0.5). The two passes were previously duplicated loops;
    they are collapsed into one prioritized loop here.

    Args:
        ingested: The ingested document (file name, full text, scan flag).

    Returns:
        A Classification; ``used_vision`` mirrors ``ingested.is_scanned``.
    """
    text = ingested.full_text or ""
    # Normalize separators so multi-word keywords match file names like
    # "delivery_note_001.pdf".
    file_name = ingested.file_name.replace("_", " ").replace("-", " ")

    # (haystack, confidence) pairs in priority order: file name beats text.
    for haystack, confidence in ((file_name, 0.85), (text, 0.7)):
        for doc_type, pattern in _KEYWORD_RULES:
            if pattern.search(haystack):
                return Classification(
                    doc_type=doc_type,
                    doc_type_display=_DOC_TYPE_DISPLAY[doc_type],
                    confidence=confidence,
                    language=_detect_language(text),
                    used_vision=ingested.is_scanned,
                )

    # No keyword hit anywhere — low-confidence "other".
    return Classification(
        doc_type="other",
        doc_type_display=_DOC_TYPE_DISPLAY["other"],
        confidence=0.5,
        language=_detect_language(text),
        used_vision=ingested.is_scanned,
    )
# ---------------------------------------------------------------------------
# vLLM/Ollama LLM classification
# ---------------------------------------------------------------------------

# Shared system prompt for both the text and vision LLM paths. The field
# names it references (doc_type, doc_type_display, confidence, language,
# used_vision) must stay in sync with the Classification schema passed to
# with_structured_output.
_CLASSIFY_SYSTEM_PROMPT = """You are a document classifier. Categorize the uploaded document into ONE of:
invoice, delivery_note, purchase_order, contract, financial_report, other.
Work only from the document content; do not fabricate. Fill ``doc_type`` with the code
('invoice', 'delivery_note', 'purchase_order', 'contract', 'financial_report', 'other'),
and ``doc_type_display`` with the display label ('Invoice', 'Delivery Note',
'Purchase Order', 'Contract', 'Financial Report', 'Other'). ``confidence`` is a
float between 0.0 and 1.0. ``language`` is the document language ('en', 'hu', 'de'),
default 'en'. ``used_vision`` is always False (the system fills it in)."""
async def _classify_llm_text(
    structured_llm, ingested: IngestedDocument
) -> Classification:
    """Text-based LLM classification (with_structured_output).

    Sends a truncated text preview (first 3000 chars) to the structured LLM
    and returns the result normalized to ``used_vision=False``.

    Args:
        structured_llm: LLM Runnable wrapped with
            ``with_structured_output(Classification)``.
        ingested: The ingested document providing ``full_text``.

    Returns:
        A Classification with ``used_vision`` forced to False.
    """
    text_preview = (ingested.full_text or "")[:3000]
    user_prompt = f"Classify the following document by type:\n\n{text_preview}"
    response = await structured_llm.ainvoke([
        SystemMessage(content=_CLASSIFY_SYSTEM_PROMPT),
        HumanMessage(content=user_prompt),
    ])
    if isinstance(response, Classification):
        response.used_vision = False
        return response
    # Some structured-output backends return a dict / pydantic-like object
    # instead of Classification — coerce it, and (bug fix) force
    # used_vision=False on this path too, mirroring how the vision path
    # forces used_vision=True.
    obj = response.model_dump() if hasattr(response, "model_dump") else dict(response)
    obj["used_vision"] = False
    return Classification(**obj)
async def _classify_llm_vision(
    structured_llm, ingested: IngestedDocument
) -> Classification:
    """Vision-based LLM classification — sends the first page image.

    Base64-encodes the first page image and asks the structured LLM to
    classify it, marking the result with ``used_vision=True``. Degrades to
    the text path when no page image is available.
    """
    first_page = ingested.pages[0] if ingested.pages else None
    if first_page is None or not first_page.image_bytes:
        # No image to look at — fall back to text-based classification.
        return await _classify_llm_text(structured_llm, ingested)

    encoded = base64.standard_b64encode(first_page.image_bytes).decode("ascii")
    vision_msg = HumanMessage(content=[
        {"type": "text", "text": "What kind of business document is shown in this image? Classify it."},
        {
            "type": "image",
            "source_type": "base64",
            "data": encoded,
            "mime_type": "image/png",
        },
    ])
    result = await structured_llm.ainvoke([
        SystemMessage(content=_CLASSIFY_SYSTEM_PROMPT),
        vision_msg,
    ])
    if isinstance(result, Classification):
        result.used_vision = True
        return result
    # Non-Classification response: coerce and stamp the vision flag.
    payload = result.model_dump() if hasattr(result, "model_dump") else dict(result)
    payload["used_vision"] = True
    return Classification(**payload)
def build_classify_node(llm=None):
    """Factory: per-doc classify node.

    Args:
        llm: A BaseChatModel-like Runnable (vLLM/Ollama/Dummy). If None or
            dummy mode, the regex-based heuristic runs.

    Returns:
        An async node mapping a DocState-shaped dict to
        ``{"documents": [ProcessedDocument]}``.
    """
    use_llm = llm is not None and not settings.is_dummy
    structured_llm = llm.with_structured_output(Classification) if use_llm else None

    async def classify_node(state: dict) -> dict:
        ingested: IngestedDocument | None = state.get("ingested")
        if ingested is None:
            return {}

        classification = None
        if structured_llm is not None and not settings.is_dummy:
            try:
                if ingested.is_scanned:
                    classification = await _classify_llm_vision(structured_llm, ingested)
                else:
                    classification = await _classify_llm_text(structured_llm, ingested)
                # Normalize anything unknown the LLM may have produced.
                if classification.doc_type not in _DOC_TYPE_DISPLAY:
                    classification.doc_type = "other"
                if classification.doc_type_display not in _DOC_TYPE_DISPLAY.values():
                    classification.doc_type_display = _DOC_TYPE_DISPLAY[classification.doc_type]
            except Exception:
                # LLM error (rate limit, network, schema fail) — signal
                # fallback to the keyword heuristic below.
                classification = None
        if classification is None:
            classification = _classify_dummy(ingested)

        pd = ProcessedDocument(ingested=ingested, classification=classification)
        return {"documents": [pd]}

    return classify_node
# Legacy backward-compat name (dummy mode) — works without the build factory.
# Built ONCE at import time; previously the factory (and its closure) was
# reconstructed on every single call.
_legacy_classify_node = build_classify_node(None)


async def classify_node(state: dict) -> dict:
    """Legacy signature (dummy mode): equivalent to build_classify_node(None)()."""
    return await _legacy_classify_node(state)