File size: 8,476 Bytes
7ff7119 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 | """risk_subgraph β aggregated risk analysis with Send API parallelism.
Topology:
START
-> basic_risk_dispatch (Send: per-doc basic risk)
-> basic_risk / noop_basic
-> domain_dispatch_node (Send: per-doc x per-applicable-check, ~30 parallel)
-> apply_domain_check
-> [if llm provided] llm_risk_dispatch (Send: per-doc LLM risk + 3-filter chain)
-> llm_risk_per_doc / noop_llm
-> plausibility_dispatch (Send: per-doc plausibility)
-> plausibility / noop_plaus
-> evidence_score_node (per-doc info)
-> duplicate_detector_node (package-level, sync, ISA 240)
-> END
If ``llm=None``, the LLM risk-analysis layer is skipped (Phase-4 backward
compatible). When ``llm`` is provided, the ``llm_risk_subgraph`` runs a 4-node
chain per document with Send fan-out: llm_risk -> filter_llm_risks ->
drop_business_normal -> drop_repeats. Together these make up the full
anti-hallucination stack ("5+1 layers").
"""
from __future__ import annotations
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from graph.states.pipeline_state import PipelineState, ProcessedDocument, Risk
from nodes.pipeline.duplicate_detector_node import duplicate_detector_node
from nodes.risk.basic_risk_node import basic_risk_node
from nodes.risk.domain_dispatch_node import (
apply_domain_check_node,
domain_dispatch_node,
)
from nodes.risk.evidence_score_node import evidence_score_node
from nodes.risk.plausibility_node import plausibility_node
from subgraphs.llm_risk_subgraph import build_llm_risk_subgraph
# ---------------------------------------------------------------------------
# Send dispatchers (basic + plausibility per-doc)
# ---------------------------------------------------------------------------
def basic_risk_dispatch(state: PipelineState) -> list[Send]:
    """Fan out one ``basic_risk`` task per fully processed document.

    A document qualifies only when both its classification and its
    extraction are present. Each task payload carries the document's file
    name, its classified ``doc_type`` and the raw extraction.

    Returns:
        One ``Send("basic_risk", ...)`` per qualifying document, in document
        order, or a single ``Send("noop_basic", {})`` when none qualifies so
        the ``noop_basic`` node (and its edge onward) is still scheduled.
    """
    eligible = [
        doc
        for doc in (state.get("documents") or [])
        if doc.classification is not None and doc.extracted is not None
    ]
    if not eligible:
        return [Send("noop_basic", {})]
    return [
        Send(
            "basic_risk",
            {
                "doc_file_name": doc.ingested.file_name,
                "doc_type": doc.classification.doc_type,
                "extracted": doc.extracted.raw,
            },
        )
        for doc in eligible
    ]
def plausibility_dispatch(state: PipelineState) -> list[Send]:
    """Fan out one ``plausibility`` task per classified and extracted document.

    Documents missing either a classification or an extraction are skipped.
    When no document qualifies, a single ``noop_plaus`` task is emitted
    instead, so the branch still reaches ``evidence_score`` through the
    ``noop_plaus`` edge.
    """
    payloads = [
        {
            "doc_file_name": doc.ingested.file_name,
            "extracted": doc.extracted.raw,
        }
        for doc in (state.get("documents") or [])
        if not (doc.classification is None or doc.extracted is None)
    ]
    if payloads:
        return [Send("plausibility", payload) for payload in payloads]
    return [Send("noop_plaus", {})]
def llm_risk_dispatch(state: PipelineState) -> list[Send]:
    """Fan out one ``llm_risk_per_doc`` task per classified and extracted document.

    Each task carries the risks already recorded for that document, so that
    ``llm_risk_node`` can build its "ALREADY FOUND" block and
    ``drop_repeats_node`` can separate repeats from genuinely new
    observations.

    Ordering note: in ``build_risk_subgraph`` this dispatcher runs after
    ``apply_domain_check`` and *before* ``plausibility_dispatcher`` and
    ``duplicate_detector``. ``state["risks"]`` therefore holds the basic and
    domain risks, plus whatever the incoming state already carried. It holds
    no plausibility or duplicate-detector risks yet. The payload key is
    ``basic_risks`` even though it also contains domain risks.

    A risk is attached to a document when its ``affected_document`` equals
    the document's file name, or when ``affected_document`` is ``None``
    (package-level risks are shared with every document). The original order
    of ``state["risks"]`` is preserved.

    Returns:
        One ``Send("llm_risk_per_doc", ...)`` per eligible document, or a
        single ``Send("noop_llm", {})`` when no document is eligible.
    """
    sends: list[Send] = []
    documents: list[ProcessedDocument] = state.get("documents") or []
    all_risks: list[Risk] = state.get("risks") or []
    for doc in documents:
        # Only fully classified and extracted documents are analysed.
        if doc.classification is None or doc.extracted is None:
            continue
        file_name = doc.ingested.file_name
        # Document-specific risks plus package-level ones
        # (affected_document=None), kept in their original order.
        doc_risks = [
            r for r in all_risks
            if r.affected_document is None or r.affected_document == file_name
        ]
        sends.append(Send("llm_risk_per_doc", {
            "doc_file_name": file_name,
            "extracted": doc.extracted.raw,
            "basic_risks": doc_risks,
        }))
    return sends or [Send("noop_llm", {})]
async def _noop_basic(state: dict) -> dict:
    """Fallback node targeted by ``basic_risk_dispatch`` when no document qualifies.

    The incoming payload is ignored and an empty state update is returned.
    """
    del state  # intentionally unused
    return dict()
async def _noop_plaus(state: dict) -> dict:
    """Fallback node targeted by ``plausibility_dispatch`` when no document qualifies.

    The incoming payload is ignored and an empty state update is returned.
    """
    del state  # intentionally unused
    return dict()
async def _noop_llm(state: dict) -> dict:
    """Fallback node targeted by ``llm_risk_dispatch`` when no document qualifies.

    The incoming payload is ignored and an empty state update is returned.
    """
    del state  # intentionally unused
    return dict()
# ---------------------------------------------------------------------------
# Subgraph builder
# ---------------------------------------------------------------------------
def build_risk_subgraph(llm=None):
    """Compile the risk subgraph (operates on the parent PipelineState).

    Wiring built below. ``-(fn)->`` marks a conditional edge whose path
    function ``fn`` returns ``Send`` objects; ``->`` is a plain edge:

        START -> basic_risk_dispatcher
            -(basic_risk_dispatch)-> basic_risk | noop_basic
            -> domain_dispatcher
            -(domain_dispatch_node)-> apply_domain_check
            [only if llm is not None]
                -> llm_risk_dispatcher
                -(llm_risk_dispatch)-> llm_risk_per_doc | noop_llm
            -> plausibility_dispatcher
            -(plausibility_dispatch)-> plausibility | noop_plaus
            -> evidence_score -> duplicate_detector -> END

    The ``*_dispatcher`` nodes are passthroughs that return ``{}``. They
    serve as the source nodes for the conditional (Send) edges.

    Args:
        llm: optional BaseChatModel-like Runnable. If None, the LLM
            risk-analysis layer (assess_risks_llm + 3 filters) is SKIPPED;
            only basic + domain + plausibility + evidence_score +
            duplicate_detector run (Phase-4 backward-compatible mode). If
            provided, the LLM layer runs after apply_domain_check and before
            the plausibility stage.

    Returns:
        The compiled graph returned by ``graph.compile()``.
    """
    graph = StateGraph(PipelineState)
    # Domain-dispatch + apply (Send fan-out for 12 of the 14 checks)
    graph.add_node("domain_dispatcher", _domain_dispatcher_passthrough)
    graph.add_node("apply_domain_check", apply_domain_check_node)
    # Basic risk (per-doc fan-out)
    graph.add_node("basic_risk_dispatcher", _basic_dispatcher_passthrough)
    graph.add_node("basic_risk", basic_risk_node)
    graph.add_node("noop_basic", _noop_basic)
    # Plausibility (per-doc fan-out)
    graph.add_node("plausibility_dispatcher", _plaus_dispatcher_passthrough)
    graph.add_node("plausibility", plausibility_node)
    graph.add_node("noop_plaus", _noop_plaus)
    # Per-doc info (evidence score)
    graph.add_node("evidence_score", evidence_score_node)
    # Package-level duplicate detection
    graph.add_node("duplicate_detector", duplicate_detector_node)
    # LLM risk subgraph (if llm provided) -- Send fan-out, one chain per doc
    has_llm = llm is not None
    if has_llm:
        llm_risk_subgraph = build_llm_risk_subgraph(llm)

        async def llm_risk_per_doc(state: dict) -> dict:
            """Run the LLM risk subgraph on the parent Send payload.

            At the end of the subgraph the 3-filter result is in ``risks``;
            it merges into the parent state's ``risks`` reducer.
            """
            result = await llm_risk_subgraph.ainvoke(state)
            risks = result.get("risks") or []
            # Emit no update at all when nothing survived the filters.
            return {"risks": risks} if risks else {}

        graph.add_node("llm_risk_dispatcher", _llm_risk_dispatcher_passthrough)
        graph.add_node("llm_risk_per_doc", llm_risk_per_doc)
        graph.add_node("noop_llm", _noop_llm)
    # Edges: dispatchers -> conditional Sends -> join nodes.
    # Every Send target (and its noop alternative) has a plain edge to the
    # next stage's node; the intent is that the parallel branches join there.
    graph.add_edge(START, "basic_risk_dispatcher")
    graph.add_conditional_edges(
        "basic_risk_dispatcher",
        basic_risk_dispatch,
        ["basic_risk", "noop_basic"],
    )
    graph.add_edge("basic_risk", "domain_dispatcher")
    graph.add_edge("noop_basic", "domain_dispatcher")
    # NOTE(review): unlike the basic / LLM / plausibility fan-outs in this
    # file, this conditional edge lists no noop fallback target. If
    # domain_dispatch_node can return an empty list (e.g. no eligible
    # documents), apply_domain_check would not be scheduled. As far as the
    # wiring here shows, nothing downstream (LLM layer, plausibility,
    # evidence_score, duplicate_detector) would then run. Confirm that
    # domain_dispatch_node always emits at least one Send.
    graph.add_conditional_edges(
        "domain_dispatcher",
        domain_dispatch_node,
        ["apply_domain_check"],
    )
    if has_llm:
        # apply_domain_check -> llm_risk_dispatcher -> llm_risk_per_doc -> plausibility_dispatcher
        graph.add_edge("apply_domain_check", "llm_risk_dispatcher")
        graph.add_conditional_edges(
            "llm_risk_dispatcher",
            llm_risk_dispatch,
            ["llm_risk_per_doc", "noop_llm"],
        )
        graph.add_edge("llm_risk_per_doc", "plausibility_dispatcher")
        graph.add_edge("noop_llm", "plausibility_dispatcher")
    else:
        # apply_domain_check -> plausibility_dispatcher (skip LLM)
        graph.add_edge("apply_domain_check", "plausibility_dispatcher")
    graph.add_conditional_edges(
        "plausibility_dispatcher",
        plausibility_dispatch,
        ["plausibility", "noop_plaus"],
    )
    graph.add_edge("plausibility", "evidence_score")
    graph.add_edge("noop_plaus", "evidence_score")
    graph.add_edge("evidence_score", "duplicate_detector")
    graph.add_edge("duplicate_detector", END)
    return graph.compile()
# Passthrough nodes: each one is the source node of a conditional Send edge
# (wired in build_risk_subgraph) and contributes no state update itself.
async def _domain_dispatcher_passthrough(state: PipelineState) -> dict:
    """Source node for the ``domain_dispatch_node`` fan-out; updates nothing."""
    del state  # intentionally unused
    return dict()
async def _basic_dispatcher_passthrough(state: PipelineState) -> dict:
    """Source node for the ``basic_risk_dispatch`` fan-out; updates nothing."""
    del state  # intentionally unused
    return dict()
async def _plaus_dispatcher_passthrough(state: PipelineState) -> dict:
    """Source node for the ``plausibility_dispatch`` fan-out; updates nothing."""
    del state  # intentionally unused
    return dict()
async def _llm_risk_dispatcher_passthrough(state: PipelineState) -> dict:
    """Source node for the ``llm_risk_dispatch`` fan-out; updates nothing."""
    del state  # intentionally unused
    return dict()
|