muhammadbinmurtza
Restructure: clauseguard as package subfolder, app_file: clauseguard/app.py
913a064 | """Orchestrator — manages the full 5-agent pipeline with OpenAI Agents SDK handoff.""" | |
| import asyncio | |
| import logging | |
| from typing import Any, Callable, List | |
| try: | |
| from agents import Agent as SdkAgent | |
| from agents import Runner as SdkRunner | |
| _SDK_AVAILABLE = True | |
| except ImportError: | |
| _SDK_AVAILABLE = False | |
| from clauseguard.agents.classifier import run_classifier | |
| from clauseguard.agents.extractor import run_extractor | |
| from clauseguard.agents.reporter import run_reporter | |
| from clauseguard.agents.risk_scorer import run_risk_scorer | |
| from clauseguard.agents.translator import run_translator | |
| from clauseguard.config.settings import MAX_CLAUSES, TIMEOUT_SECONDS | |
| from clauseguard.models.clause import ClauseList | |
| from clauseguard.models.findings import ScoredClause | |
| from clauseguard.models.report import FinalReport | |
| # ── Live Agent Event System ── | |
| # The orchestrator emits events via a callback so the UI can show live status. | |
| # Default is a no-op; the UI replaces it with a Streamlit-aware callback at runtime. | |
| _live_event_callback: Callable[[str, str, dict[str, Any]], None] = lambda agent, status, details: None | |
| def set_event_callback(cb: Callable[[str, str, dict[str, Any]], None]) -> None: | |
| """Register a callback for live agent events (called by the UI). | |
| Args: | |
| cb: Function receiving (agent_name, status, details_dict). | |
| status is one of: 'running', 'completed', 'failed'. | |
| details may contain 'message', 'clause_count', 'severity_counts', etc. | |
| """ | |
| global _live_event_callback | |
| _live_event_callback = cb | |
| def _emit(agent: str, status: str, **details: Any) -> None: | |
| """Emit a live event for the given agent.""" | |
| try: | |
| _live_event_callback(agent, status, details) | |
| except Exception: | |
| pass | |
| logger = logging.getLogger(__name__) | |
async def run_pipeline(file_content: str, filename: str) -> FinalReport:
    """Run contract text through the five agents and return the compiled report.

    Stage order: Extractor -> Classifier -> Risk Scorer -> Translator -> Reporter.

    When the OpenAI Agents SDK was importable, the SDK handoff pipeline is
    tried first and its report is returned if it produces one. Otherwise the
    stages are invoked directly through the ``_step_*`` helpers, each of which
    handles its own errors, so a failed stage degrades the report (flagged as
    partial) instead of aborting the run. On the direct path the Reporter
    stage is always invoked.

    Args:
        file_content: The extracted text content of the contract.
        filename: Name of the contract file (used in the report).

    Returns:
        The FinalReport built by the Reporter stage, or the one returned by
        the SDK pipeline when it yields a result.
    """
    clause_list: ClauseList = ClauseList()
    scored_clauses: List[ScoredClause] = []
    partial = False
    truncation_note = ""

    if _SDK_AVAILABLE:
        logger.info("OpenAI Agents SDK handoff pipeline active")
        sdk_report = await _run_sdk_pipeline(file_content, filename)
        if sdk_report is not None:
            return sdk_report

    clause_list = await _step_extract(file_content, filename)

    if clause_list.total_clauses > 0:
        truncation_note = _check_truncation(clause_list, file_content)
        if clause_list.total_clauses >= MAX_CLAUSES:
            logger.warning("Document truncated to %d clauses (max %d)", clause_list.total_clauses, MAX_CLAUSES)

        clause_list = await _step_classify(clause_list)
        scored_clauses = await _step_risk_score(clause_list)

        if not scored_clauses:
            # Scoring yielded nothing: mark the run partial and substitute
            # placeholder findings so the clauses still reach the report.
            partial = True
            logger.warning("Risk scorer produced no results — using fallback severity (MEDIUM) for all clauses")
            scored_clauses = _build_fallback_scored_clauses(clause_list)

        # Real and fallback findings alike are translated, provided any exist.
        if scored_clauses:
            scored_clauses = await _step_translate(scored_clauses)
    else:
        # Nothing was extracted, so the report cannot be complete.
        partial = True

    contract_type = clause_list.contract_type or "Other"
    return await _step_report(scored_clauses, filename, contract_type, partial, truncation_note)
async def _run_sdk_pipeline(file_content: str, filename: str) -> FinalReport | None:
    """Run the pipeline using OpenAI Agents SDK for handoff demonstration.

    Handoff chain: Extractor -> Classifier -> Risk Scorer -> Translator -> Reporter

    The SDK run's output is only logged; nothing here converts it into a
    FinalReport. As written, this function therefore returns None on every
    path and the caller always continues with the direct-call pipeline.
    NOTE(review): that means the document is sent to the model twice whenever
    the SDK is installed — confirm this is intended.

    The SDK run is bounded by a timeout so a stalled handoff chain cannot
    block the direct-call fallback indefinitely.

    Args:
        file_content: The extracted text content of the contract.
        filename: Name of the contract file, embedded in the initial prompt.

    Returns:
        None, signalling the caller to fall back to direct agent calls.
    """
    try:
        from clauseguard.config.prompts import (
            CLASSIFIER_SYSTEM_PROMPT,
            EXTRACTOR_SYSTEM_PROMPT,
            REPORTER_SYSTEM_PROMPT,
            RISK_SCORER_SYSTEM_PROMPT,
            TRANSLATOR_SYSTEM_PROMPT,
        )
        from clauseguard.config.settings import MODEL_NAME

        # (display name, system prompt) for each stage, in handoff order.
        stage_prompts = [
            ("Contract Extractor", EXTRACTOR_SYSTEM_PROMPT),
            ("Clause Classifier", CLASSIFIER_SYSTEM_PROMPT),
            ("Risk Scorer", RISK_SCORER_SYSTEM_PROMPT),
            ("Plain English Translator", TRANSLATOR_SYSTEM_PROMPT),
            ("Report Compiler", REPORTER_SYSTEM_PROMPT),
        ]
        chain = [
            SdkAgent(name=agent_name, instructions=prompt, model=MODEL_NAME)
            for agent_name, prompt in stage_prompts
        ]
        # Each agent may hand off only to the next one; the last has none.
        for upstream, downstream in zip(chain, chain[1:]):
            upstream.handoffs = [downstream]

        logger.info("SDK handoff chain: Extractor -> Classifier -> Risk Scorer -> Translator -> Reporter")

        # One per-stage budget for every agent in the chain, mirroring the
        # timeout each direct-call step gets individually.
        result = await asyncio.wait_for(
            SdkRunner.run(
                chain[0],
                f"Extract all clauses from this contract file '{filename}':\n\n{file_content}",
            ),
            timeout=TIMEOUT_SECONDS * len(chain),
        )
        logger.info("SDK pipeline completed with %d steps", len(result.new_items) if result else 0)
    except asyncio.TimeoutError:
        # Handled separately because a timeout carries no message text.
        logger.warning("SDK handoff pipeline timed out, falling back to direct calls")
    except Exception as e:
        logger.warning("SDK handoff pipeline not fully available, falling back to direct calls: %s", e)
    return None
def _check_truncation(clause_list: ClauseList, original_text: str) -> str:
    """Return a user-facing truncation notice, or "" when none applies.

    A clause count at or above MAX_CLAUSES is taken as the sign that the
    document hit the clause limit.

    Args:
        clause_list: Extraction result whose ``total_clauses`` is inspected.
        original_text: Full document text, used only for a rough word count.

    Returns:
        The notice text, or an empty string when the limit was not reached.
    """
    if clause_list.total_clauses < MAX_CLAUSES:
        return ""

    # Whitespace-separated token count; an approximation for the message only.
    approx_words = len(original_text.split())
    notice_parts = [
        f"Document exceeded maximum clause limit ({MAX_CLAUSES}). ",
        f"Only the first ~{MAX_CLAUSES} clauses were processed from a document ",
        f"of approximately {approx_words} words. Some clauses may not appear in this report.",
    ]
    return "".join(notice_parts)
def _build_fallback_scored_clauses(clause_list: ClauseList) -> List[ScoredClause]:
    """Build scored clauses with MEDIUM severity when the risk scorer fails.

    This ensures users still see their clauses in the report even when the AI
    risk analysis could not complete, rather than showing 'no issues' misleadingly.

    Args:
        clause_list: Clauses to wrap; each entry of ``clause_list.clauses``
            gets one placeholder finding.

    Returns:
        One ScoredClause per input clause, each carrying a
        "Needs Human Review" finding at MEDIUM severity.
    """
    from clauseguard.models.findings import RiskFinding, ScoredClause, Severity

    fallback: List[ScoredClause] = []
    for clause in clause_list.clauses:
        # This fallback can also run after classification failed, because the
        # classify step hands back its input unchanged on error. Do not assume
        # the clause type is a populated enum — presumably it may be unset or
        # a plain string then; verify against the clause model.
        clause_type = getattr(clause, "clause_type", None)
        if clause_type is None:
            type_label = "Unknown"
        else:
            type_label = getattr(clause_type, "value", clause_type)

        finding = RiskFinding(
            clause_id=clause.id,
            severity=Severity.MEDIUM,
            risk_title="Needs Human Review",
            risk_reason=(
                "The automated risk analyzer could not evaluate this clause. "
                f"Type: {type_label}. "
                "Please review manually or consult legal counsel."
            ),
            recommended_action="Review this clause manually — the AI risk scorer could not complete.",
        )
        fallback.append(ScoredClause(clause=clause, finding=finding))
    return fallback
async def _step_extract(file_content: str, filename: str) -> ClauseList:
    """Run the Extractor stage, containing its failures.

    Emits 'running', then 'completed' or 'failed' live events for the
    "Extractor" agent.

    Args:
        file_content: Contract text handed to the extractor.
        filename: Contract file name handed to the extractor.

    Returns:
        The extractor's ClauseList, or an empty ClauseList when the call
        times out or raises.
    """
    stage = "Extractor"
    try:
        logger.info("Extracting clauses from document...")
        _emit(stage, "running", message="Segmenting document into individual clauses")
        extracted = await asyncio.wait_for(run_extractor(file_content, filename), TIMEOUT_SECONDS)
        found = extracted.total_clauses
        _emit(stage, "completed", message=f"Found {found} clauses", clause_count=found)
        return extracted
    except asyncio.TimeoutError:
        _emit(stage, "failed", message="Timed out")
        logger.error("Extractor agent timed out after %ds", TIMEOUT_SECONDS)
        return ClauseList()
    except Exception as exc:
        # The event message is capped at 80 characters; the log keeps it all.
        _emit(stage, "failed", message=str(exc)[:80])
        logger.error("Extractor agent failed: %s", exc)
        return ClauseList()
async def _step_classify(clause_list: ClauseList) -> ClauseList:
    """Run the Classifier stage, containing its failures.

    Emits 'running', then 'completed' or 'failed' live events for the
    "Classifier" agent.

    Args:
        clause_list: Extracted clauses to be labelled.

    Returns:
        The classifier's ClauseList, or the input ``clause_list`` unchanged
        when the call times out or raises.
    """
    stage = "Classifier"
    try:
        logger.info("Classifying %d clauses...", clause_list.total_clauses)
        _emit(stage, "running", message=f"Labeling {clause_list.total_clauses} clauses by type")
        classified = await asyncio.wait_for(run_classifier(clause_list), TIMEOUT_SECONDS)
        _emit(stage, "completed", message=f"Detected contract type: {classified.contract_type}")
        return classified
    except asyncio.TimeoutError:
        _emit(stage, "failed", message="Timed out")
        logger.error("Classifier agent timed out")
        return clause_list
    except Exception as exc:
        # The event message is capped at 80 characters; the log keeps it all.
        _emit(stage, "failed", message=str(exc)[:80])
        logger.error("Classifier agent failed: %s", exc)
        return clause_list
async def _step_risk_score(clause_list: ClauseList) -> List[ScoredClause]:
    """Run the Risk Scorer stage, containing its failures.

    Emits 'running', then 'completed' or 'failed' live events for the
    "Risk Scorer" agent. The 'completed' event carries the number of
    CRITICAL and HIGH findings.

    Args:
        clause_list: Clauses to be scored.

    Returns:
        The scorer's list of ScoredClause, or an empty list when the call
        times out or raises.
    """
    stage = "Risk Scorer"
    try:
        logger.info("Scoring risks for %d clauses...", clause_list.total_clauses)
        _emit(stage, "running", message=f"Evaluating severity for {clause_list.total_clauses} clauses")
        scored = await asyncio.wait_for(run_risk_scorer(clause_list), TIMEOUT_SECONDS)

        # Tally the two most severe levels for the status event.
        severities = [item.finding.severity.value for item in scored]
        crit = severities.count("CRITICAL")
        high = severities.count("HIGH")
        _emit(
            stage,
            "completed",
            message=f"Found {crit} critical, {high} high-risk clauses",
            severity_counts={"critical": crit, "high": high},
        )
        return scored
    except asyncio.TimeoutError:
        _emit(stage, "failed", message="Timed out")
        logger.error("Risk Scorer agent timed out")
        return []
    except Exception as exc:
        # The event message is capped at 80 characters; the log keeps it all.
        _emit(stage, "failed", message=str(exc)[:80])
        logger.error("Risk Scorer agent failed: %s", exc)
        return []
async def _step_translate(scored_clauses: List[ScoredClause]) -> List[ScoredClause]:
    """Run the Translator stage, containing its failures.

    Emits 'running', then 'completed' or 'failed' live events for the
    "Translator" agent.

    Args:
        scored_clauses: Scored clauses to be translated.

    Returns:
        The translator's output, or the input ``scored_clauses`` unchanged
        when the call times out or raises.
    """
    stage = "Translator"
    try:
        total = len(scored_clauses)
        logger.info("Translating %d clauses to plain English...", total)
        _emit(stage, "running", message=f"Writing plain-English versions + negotiation tips for {total} clauses")
        translated = await asyncio.wait_for(run_translator(scored_clauses), TIMEOUT_SECONDS)
        _emit(stage, "completed", message="Plain English translations ready")
        return translated
    except asyncio.TimeoutError:
        _emit(stage, "failed", message="Timed out")
        logger.error("Translator agent timed out")
        return scored_clauses
    except Exception as exc:
        # The event message is capped at 80 characters; the log keeps it all.
        _emit(stage, "failed", message=str(exc)[:80])
        logger.error("Translator agent failed: %s", exc)
        return scored_clauses
async def _step_report(
    scored_clauses: List[ScoredClause],
    filename: str,
    contract_type: str,
    partial: bool = False,
    truncation_note: str = "",
) -> FinalReport:
    """Run the Reporter stage, containing its failures.

    No timeout is applied at this level. Emits 'running', then 'completed'
    or 'failed' live events for the "Reporter" agent.

    Args:
        scored_clauses: Findings to compile into the report.
        filename: Contract file name; also used as the report's contract name
            if the reporter fails.
        contract_type: Contract type passed through to the reporter.
        partial: Flag passed through to the reporter.
        truncation_note: Truncation notice passed through to the reporter.

    Returns:
        The reporter's FinalReport, or a minimal FinalReport with
        ``processed_normally=False`` when the reporter raises.
    """
    stage = "Reporter"
    try:
        logger.info("Building final report...")
        _emit(stage, "running", message="Compiling final risk report")
        report = await run_reporter(scored_clauses, filename, contract_type, partial, truncation_note)
        score = report.summary.overall_score
        _emit(stage, "completed", message=f"Report ready — score: {score}/10")
        return report
    except Exception as exc:
        # The event message is capped at 80 characters; the log keeps it all.
        _emit(stage, "failed", message=str(exc)[:80])
        logger.error("Reporter agent failed: %s", exc)
        return FinalReport(contract_name=filename, processed_normally=False)