"""Orchestrator — manages the full 5-agent pipeline with OpenAI Agents SDK handoff."""
import asyncio
import logging
from typing import Any, Callable, List
try:
from agents import Agent as SdkAgent
from agents import Runner as SdkRunner
_SDK_AVAILABLE = True
except ImportError:
_SDK_AVAILABLE = False
from clauseguard.agents.classifier import run_classifier
from clauseguard.agents.extractor import run_extractor
from clauseguard.agents.reporter import run_reporter
from clauseguard.agents.risk_scorer import run_risk_scorer
from clauseguard.agents.translator import run_translator
from clauseguard.config.settings import MAX_CLAUSES, TIMEOUT_SECONDS
from clauseguard.models.clause import ClauseList
from clauseguard.models.findings import ScoredClause
from clauseguard.models.report import FinalReport
# ── Live Agent Event System ──
# The orchestrator reports agent progress through a callback so the UI can
# render live status. Until the UI installs its own (Streamlit-aware) callback
# at runtime, events go to a sink that simply ignores them.
def _discard_event(agent: str, status: str, details: dict[str, Any]) -> None:
    """Default event sink: accept an event and do nothing with it."""
    return None


_live_event_callback: Callable[[str, str, dict[str, Any]], None] = _discard_event
def set_event_callback(cb: Callable[[str, str, dict[str, Any]], None]) -> None:
"""Register a callback for live agent events (called by the UI).
Args:
cb: Function receiving (agent_name, status, details_dict).
status is one of: 'running', 'completed', 'failed'.
details may contain 'message', 'clause_count', 'severity_counts', etc.
"""
global _live_event_callback
_live_event_callback = cb
def _emit(agent: str, status: str, **details: Any) -> None:
    """Forward a live status event for ``agent`` to the registered callback.

    Emission is best-effort: an exception raised by the callback is never
    propagated to the pipeline.

    Args:
        agent: Display name of the agent the event is about.
        status: Event status string passed through to the callback.
        **details: Extra key/value pairs, delivered to the callback as a dict.
    """
    try:
        _live_event_callback(agent, status, details)
    except Exception:
        # Deliberately swallowed so a broken UI hook cannot abort the
        # pipeline, but recorded (with traceback) so it can be diagnosed.
        # getLogger(__name__) yields the same logger object the module uses.
        logging.getLogger(__name__).debug(
            "Live event callback failed for agent=%s status=%s",
            agent,
            status,
            exc_info=True,
        )
# Module-level logger, named after this module; used by the pipeline steps below.
logger = logging.getLogger(__name__)
async def run_pipeline(file_content: str, filename: str) -> FinalReport:
    """Execute the full 5-agent pipeline on contract text.

    Pipeline: Extractor -> Classifier -> Risk Scorer -> Translator -> Reporter

    When the OpenAI Agents SDK is importable, the SDK handoff pipeline is tried
    first and its report is returned if it produces one; otherwise the stages
    are invoked directly through the ``_step_*`` wrappers. Each wrapper catches
    its own errors, so a failing agent degrades the result instead of aborting:

    * no clauses extracted -> report is flagged partial;
    * risk scorer returns nothing -> every clause gets a fallback finding and
      the report is flagged partial;
    * the Reporter step always runs.

    Args:
        file_content: The extracted text content of the contract.
        filename: Name of the contract file (used in the report).

    Returns:
        A FinalReport containing the risk analysis. The function is designed
        not to raise: agent steps and the fallback builder are all guarded.
    """
    scored_clauses: List[ScoredClause] = []
    partial = False
    truncation_note = ""

    if _SDK_AVAILABLE:
        logger.info("OpenAI Agents SDK handoff pipeline active")
        sdk_report = await _run_sdk_pipeline(file_content, filename)
        if sdk_report is not None:
            return sdk_report

    clause_list = await _step_extract(file_content, filename)

    if clause_list.total_clauses > 0:
        truncation_note = _check_truncation(clause_list, file_content)
        if clause_list.total_clauses >= MAX_CLAUSES:
            logger.warning(
                "Document truncated to %d clauses (max %d)",
                clause_list.total_clauses,
                MAX_CLAUSES,
            )

        clause_list = await _step_classify(clause_list)
        scored_clauses = await _step_risk_score(clause_list)

        if not scored_clauses:
            # Scorer failed or returned nothing: keep the clauses visible in
            # the report with a placeholder finding rather than showing none.
            partial = True
            logger.warning("Risk scorer produced no results — using fallback severity (MEDIUM) for all clauses")
            try:
                scored_clauses = _build_fallback_scored_clauses(clause_list)
            except Exception as e:
                # The fallback builder is plain synchronous code with no guard
                # of its own; keep the no-raise contract and still reach the
                # Reporter with an empty list.
                logger.error("Could not build fallback findings: %s", e)
                scored_clauses = []

        # Translate whatever we ended up with (real or fallback findings).
        if scored_clauses:
            scored_clauses = await _step_translate(scored_clauses)
    else:
        # Nothing was extracted, so the report cannot be complete.
        partial = True

    contract_type = clause_list.contract_type or "Other"
    return await _step_report(scored_clauses, filename, contract_type, partial, truncation_note)
async def _run_sdk_pipeline(file_content: str, filename: str) -> FinalReport | None:
    """Run the pipeline through the OpenAI Agents SDK as a handoff demonstration.

    Handoff chain: Extractor -> Classifier -> Risk Scorer -> Translator -> Reporter

    The SDK run's output is only logged; it is not converted into a
    ``FinalReport``. As a result this function returns ``None`` on every path
    (success, timeout, or error), which tells the caller to fall back to the
    direct agent calls.

    Args:
        file_content: Contract text that is embedded in the first agent's prompt.
        filename: Contract file name, embedded in the same prompt.

    Returns:
        Always ``None`` in the current implementation.
    """
    try:
        from clauseguard.config.prompts import (
            CLASSIFIER_SYSTEM_PROMPT,
            EXTRACTOR_SYSTEM_PROMPT,
            REPORTER_SYSTEM_PROMPT,
            RISK_SCORER_SYSTEM_PROMPT,
            TRANSLATOR_SYSTEM_PROMPT,
        )
        from clauseguard.config.settings import MODEL_NAME

        # (agent display name, system prompt) in handoff order.
        stage_specs = (
            ("Contract Extractor", EXTRACTOR_SYSTEM_PROMPT),
            ("Clause Classifier", CLASSIFIER_SYSTEM_PROMPT),
            ("Risk Scorer", RISK_SCORER_SYSTEM_PROMPT),
            ("Plain English Translator", TRANSLATOR_SYSTEM_PROMPT),
            ("Report Compiler", REPORTER_SYSTEM_PROMPT),
        )
        chain = [
            SdkAgent(name=agent_name, instructions=prompt, model=MODEL_NAME)
            for agent_name, prompt in stage_specs
        ]
        # Each agent may hand off only to the next one; the last has no handoff.
        for upstream, downstream in zip(chain, chain[1:]):
            upstream.handoffs = [downstream]

        logger.info("SDK handoff chain: Extractor -> Classifier -> Risk Scorer -> Translator -> Reporter")
        # Bound the whole chain like every other agent call in this module:
        # one TIMEOUT_SECONDS budget per stage. Without this a hung SDK call
        # would stall the pipeline indefinitely.
        result = await asyncio.wait_for(
            SdkRunner.run(
                chain[0],
                f"Extract all clauses from this contract file '{filename}':\n\n{file_content}",
            ),
            timeout=TIMEOUT_SECONDS * len(chain),
        )
        logger.info("SDK pipeline completed with %d steps", len(result.new_items) if result else 0)
    except asyncio.TimeoutError:
        # str(TimeoutError) is empty, so give it its own readable message.
        logger.warning("SDK handoff pipeline timed out, falling back to direct calls")
    except Exception as e:
        logger.warning("SDK handoff pipeline not fully available, falling back to direct calls: %s", e)
    return None
def _check_truncation(clause_list: ClauseList, original_text: str) -> str:
    """Build the truncation notice for a document that hit the clause cap.

    Args:
        clause_list: Extraction result; only ``total_clauses`` is read.
        original_text: Full contract text, used for an approximate word count.

    Returns:
        A notice string when ``total_clauses`` has reached ``MAX_CLAUSES``,
        otherwise the empty string.
    """
    if clause_list.total_clauses < MAX_CLAUSES:
        return ""

    # Whitespace-separated token count: an approximation, not a true word count.
    approx_words = len(original_text.split())
    notice = (
        f"Document exceeded maximum clause limit ({MAX_CLAUSES}). "
        f"Only the first ~{MAX_CLAUSES} clauses were processed from a document "
        f"of approximately {approx_words} words. Some clauses may not appear in this report."
    )
    return notice
def _build_fallback_scored_clauses(clause_list: ClauseList) -> List[ScoredClause]:
    """Build scored clauses with MEDIUM severity when the risk scorer fails.

    This ensures users still see their clauses in the report even when the AI
    risk analysis could not complete, rather than showing 'no issues' misleadingly.

    Args:
        clause_list: Extracted clauses; one placeholder finding is produced
            for each entry of ``clause_list.clauses``.

    Returns:
        One ``ScoredClause`` per input clause, each carrying a
        "Needs Human Review" finding. Empty when there are no clauses.
    """
    from clauseguard.models.findings import RiskFinding, ScoredClause, Severity

    return [
        ScoredClause(
            clause=clause,
            finding=RiskFinding(
                clause_id=clause.id,
                severity=Severity.MEDIUM,
                risk_title="Needs Human Review",
                risk_reason=(
                    "The automated risk analyzer could not evaluate this clause. "
                    f"Type: {clause.clause_type.value}. "
                    "Please review manually or consult legal counsel."
                ),
                recommended_action="Review this clause manually — the AI risk scorer could not complete.",
            ),
        )
        for clause in clause_list.clauses
    ]
async def _step_extract(file_content: str, filename: str) -> ClauseList:
    """Invoke the Extractor agent under a timeout, reporting progress events.

    Args:
        file_content: Contract text handed to the extractor.
        filename: Contract file name handed to the extractor.

    Returns:
        The extractor's clause list, or an empty ``ClauseList()`` if the call
        timed out or raised.
    """
    try:
        logger.info("Extracting clauses from document...")
        _emit("Extractor", "running", message="Segmenting document into individual clauses")
        extracted = await asyncio.wait_for(
            run_extractor(file_content, filename),
            timeout=TIMEOUT_SECONDS,
        )
        # Kept inside the try: a malformed result is treated as a failure too.
        _emit(
            "Extractor",
            "completed",
            message=f"Found {extracted.total_clauses} clauses",
            clause_count=extracted.total_clauses,
        )
        return extracted
    except asyncio.TimeoutError:
        _emit("Extractor", "failed", message="Timed out")
        logger.error("Extractor agent timed out after %ds", TIMEOUT_SECONDS)
    except Exception as exc:
        _emit("Extractor", "failed", message=str(exc)[:80])
        logger.error("Extractor agent failed: %s", exc)
    # Both failure paths land here: hand back an empty clause list.
    return ClauseList()
async def _step_classify(clause_list: ClauseList) -> ClauseList:
    """Invoke the Classifier agent under a timeout, reporting progress events.

    Args:
        clause_list: Extracted clauses to be labelled.

    Returns:
        The classifier's output, or the unmodified input ``clause_list`` if
        the call timed out or raised.
    """
    try:
        logger.info("Classifying %d clauses...", clause_list.total_clauses)
        _emit("Classifier", "running", message=f"Labeling {clause_list.total_clauses} clauses by type")
        classified = await asyncio.wait_for(
            run_classifier(clause_list),
            timeout=TIMEOUT_SECONDS,
        )
        _emit("Classifier", "completed", message=f"Detected contract type: {classified.contract_type}")
        return classified
    except asyncio.TimeoutError:
        _emit("Classifier", "failed", message="Timed out")
        logger.error("Classifier agent timed out")
    except Exception as exc:
        _emit("Classifier", "failed", message=str(exc)[:80])
        logger.error("Classifier agent failed: %s", exc)
    # Classification is optional for downstream steps: pass the input through.
    return clause_list
async def _step_risk_score(clause_list: ClauseList) -> List[ScoredClause]:
    """Invoke the Risk Scorer agent under a timeout, reporting progress events.

    Args:
        clause_list: Clauses to be scored.

    Returns:
        The scorer's list of scored clauses, or ``[]`` if the call timed out
        or raised (including while tallying severities for the status event).
    """
    try:
        logger.info("Scoring risks for %d clauses...", clause_list.total_clauses)
        _emit("Risk Scorer", "running", message=f"Evaluating severity for {clause_list.total_clauses} clauses")
        scored = await asyncio.wait_for(
            run_risk_scorer(clause_list),
            timeout=TIMEOUT_SECONDS,
        )

        # Tally the two top severities for the live status message.
        critical_total = 0
        high_total = 0
        for item in scored:
            level = item.finding.severity.value
            if level == "CRITICAL":
                critical_total += 1
            elif level == "HIGH":
                high_total += 1

        _emit(
            "Risk Scorer",
            "completed",
            message=f"Found {critical_total} critical, {high_total} high-risk clauses",
            severity_counts={"critical": critical_total, "high": high_total},
        )
        return scored
    except asyncio.TimeoutError:
        _emit("Risk Scorer", "failed", message="Timed out")
        logger.error("Risk Scorer agent timed out")
    except Exception as exc:
        _emit("Risk Scorer", "failed", message=str(exc)[:80])
        logger.error("Risk Scorer agent failed: %s", exc)
    # Both failure paths land here: no scored clauses.
    return []
async def _step_translate(scored_clauses: List[ScoredClause]) -> List[ScoredClause]:
    """Invoke the Translator agent under a timeout, reporting progress events.

    Args:
        scored_clauses: Scored clauses to be translated.

    Returns:
        The translator's output, or the untranslated input list if the call
        timed out or raised.
    """
    try:
        clause_total = len(scored_clauses)
        logger.info("Translating %d clauses to plain English...", clause_total)
        _emit(
            "Translator",
            "running",
            message=f"Writing plain-English versions + negotiation tips for {clause_total} clauses",
        )
        translated = await asyncio.wait_for(
            run_translator(scored_clauses),
            timeout=TIMEOUT_SECONDS,
        )
        _emit("Translator", "completed", message="Plain English translations ready")
        return translated
    except asyncio.TimeoutError:
        _emit("Translator", "failed", message="Timed out")
        logger.error("Translator agent timed out")
    except Exception as exc:
        _emit("Translator", "failed", message=str(exc)[:80])
        logger.error("Translator agent failed: %s", exc)
    # Translation is an enrichment: on failure keep the scored clauses as-is.
    return scored_clauses
async def _step_report(
    scored_clauses: List[ScoredClause],
    filename: str,
    contract_type: str,
    partial: bool = False,
    truncation_note: str = "",
) -> FinalReport:
    """Run the Reporter agent with error handling.

    No outer timeout is applied here — the docstring of the original step
    states that internal timeouts handle the LLM calls.

    Args:
        scored_clauses: Scored (and possibly translated) clauses to report on.
        filename: Contract file name; also used as the fallback report's name.
        contract_type: Contract type label passed to the reporter.
        partial: Whether earlier stages produced incomplete data.
        truncation_note: Truncation notice to pass to the reporter, or "".

    Returns:
        The reporter's ``FinalReport``. If the reporter raises, a minimal
        ``FinalReport`` with ``processed_normally=False`` is returned instead.
    """
    try:
        logger.info("Building final report...")
        _emit("Reporter", "running", message="Compiling final risk report")
        result = await run_reporter(scored_clauses, filename, contract_type, partial, truncation_note)
        _emit("Reporter", "completed", message=f"Report ready — score: {result.summary.overall_score}/10")
        return result
    except Exception as e:
        _emit("Reporter", "failed", message=str(e)[:80])
        logger.error("Reporter agent failed: %s", e)
        return FinalReport(contract_name=filename, processed_normally=False)
|