Spaces:
Sleeping
Sleeping
| """ | |
| ClauseGuard — World's Best Legal Contract Analysis Tool (v4.3) | |
| ═══════════════════════════════════════════════════════════════ | |
| PERF v4.3: | |
| • PERF: Upgraded embedder to BAAI/bge-small-en-v1.5 (+21% retrieval accuracy) | |
| • PERF: Batched clause classification (one forward pass per batch of up to 8 clauses) | |
| • PERF: ONNX INT8 quantized model support (2-4x faster on CPU) | |
| • PERF: torch.set_num_threads(2) to prevent CPU thrashing | |
| • NEW: ml/export_onnx_v2.py — full merge → ONNX → quantize pipeline | |
| Fixes in v4.2: | |
| • FIX: NLI now uses CrossEncoder.predict() — contradiction detection actually works | |
| • FIX: BoundedCache guards every operation with threading.RLock — no concurrent-access races | |
| • FIX: Pre-compiled ALL regex patterns at module level (perf) | |
| • FIX: Added missing regex labels to RISK_MAP/DESC_MAP | |
| • FIX: Extension risk formula matches backend | |
| • FIX: Extension API_BASE URL corrected | |
| • FIX: API CORS for localhost now requires explicit opt-in | |
| Fixes in v4.1: | |
| • FIX: Bounded LRU caches (chunk_cache, prediction_cache) — no more memory leaks | |
| • FIX: NLI input format — pass (text_a, text_b) tuple, not [SEP]-concatenated string | |
| • FIX: Classifier max_length raised to 512 (was 256, which truncated legal clauses) | |
| • FIX: Risk score formula — absolute risk, not normalized by total_clauses | |
| • FIX: Train/inference alignment — use softmax+argmax for single-label model | |
| • FIX: Added missing regex fallback patterns for more CUAD categories | |
| • FIX: Entity extraction batching — single pipeline call instead of sequential calls | |
| • PERF: Shared model singleton via models.py module | |
| • PERF: LRU-bounded caches everywhere | |
| Carried from v4.0: | |
| • OCR support for scanned PDFs (docTR engine with smart native/scanned routing) | |
| • Contract Q&A Chatbot (RAG: embedding retrieval + HF Inference API streaming) | |
| • Clause Redlining (3-tier: template lookup + RAG + LLM refinement) | |
| • Fixed CUAD label mapping (added missing index 6) | |
| • Structure-aware clause splitting | |
| • Real NLI contradiction detection via cross-encoder model | |
| • ML-based Legal NER with regex fallback | |
| • Semantic compliance checking with negation handling | |
| • Improved obligation extraction with false-positive filtering | |
| • LLM-powered clause explanations | |
| • Per-session temp files (no filename collisions between sessions) | |
| • Model health reporting | |
| Models: | |
| • Clause classifier: Mokshith31/legalbert-contract-clause-classification | |
| (LoRA adapter on nlpaueb/legal-bert-base-uncased, 41 CUAD classes); an ONNX INT8 | |
| export (default gaurv007/clauseguard-onnx-int8) is tried first when optimum is installed | |
| • Legal NER: matterstack/legal-bert-ner (token classification) | |
| • NLI: cross-encoder/nli-deberta-v3-base (contradiction detection) | |
| • Embeddings: BAAI/bge-small-en-v1.5 (RAG retrieval; replaced all-MiniLM-L6-v2 in v4.3) | |
| • OCR: docTR fast_base + crnn_vgg16_bn (scanned PDF extraction) | |
| • LLM: Qwen/Qwen2.5-7B-Instruct via HF Inference API (chatbot + redlining) | |
| """ | |
| import os | |
| import re | |
| import json | |
| import csv | |
| import io | |
| import uuid | |
| import tempfile | |
| import hashlib | |
| import threading | |
| from collections import defaultdict, OrderedDict | |
| from datetime import datetime | |
| from functools import lru_cache | |
| import gradio as gr | |
| import numpy as np | |
| # ββ Document parsers (soft-fail) ββββββββββββββββββββββββββββββββββββ | |
| try: | |
| import pdfplumber | |
| _HAS_PDF = True | |
| except Exception: | |
| _HAS_PDF = False | |
| try: | |
| from docx import Document as DocxDocument | |
| _HAS_DOCX = True | |
| except Exception: | |
| _HAS_DOCX = False | |
| # ββ PyTorch / Transformers (soft-fail) ββββββββββββββββββββββββββββββββ | |
| _HAS_TORCH = False | |
| _HAS_NER_MODEL = False | |
| _HAS_NLI_MODEL = False | |
| try: | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, AutoModelForSequenceClassification, | |
| AutoModelForTokenClassification, pipeline | |
| ) | |
| from peft import PeftModel | |
| _HAS_TORCH = True | |
| # PERF v4.3: Limit PyTorch threads to avoid CPU thrashing under concurrent requests. | |
| # HF Spaces CPU-basic has 2 vCPUs. Reserve 1 thread for Gradio server. | |
| torch.set_num_threads(2) | |
| torch.set_num_interop_threads(1) | |
| except Exception: | |
| pass | |
| # ββ ONNX Runtime (soft-fail, for quantized model) βββββββββββββββββββββ | |
| _HAS_ORT = False | |
| try: | |
| from optimum.onnxruntime import ORTModelForSequenceClassification as _ORTModel | |
| _HAS_ORT = True | |
| except ImportError: | |
| pass | |
| # ββ CrossEncoder for NLI (soft-fail) ββββββββββββββββββββββββββββββββββ | |
| _HAS_CROSS_ENCODER = False | |
| try: | |
| from sentence_transformers import CrossEncoder as _CrossEncoder | |
| _HAS_CROSS_ENCODER = True | |
| except ImportError: | |
| pass | |
| # ββ Import submodules βββββββββββββββββββββββββββββββββββββββββββββββ | |
| from compare import compare_contracts, render_comparison_html | |
| from obligations import extract_obligations, render_obligations_html | |
| from compliance import check_compliance, render_compliance_html | |
| from ocr_engine import parse_pdf_smart, get_ocr_status | |
| from chatbot import index_contract, chat_respond, get_chatbot_status | |
| from redlining import generate_redlines, render_redlines_html | |
# =====================================================================
# 1. CONFIGURATION — FIXED label mapping (41 labels, index 6 restored)
# =====================================================================
# CUAD clause categories. A label's position in this list is the classifier
# output index it names (the classifiers look up CUAD_LABELS[argmax]), so the
# order must never change.
CUAD_LABELS = [
    "Document Name",                        # 0
    "Parties",                              # 1
    "Agreement Date",                       # 2
    "Effective Date",                       # 3
    "Expiration Date",                      # 4
    "Renewal Term",                         # 5
    "Notice Period to Terminate Renewal",   # 6 ← WAS MISSING
    "Governing Law",                        # 7
    "Most Favored Nation",                  # 8
    "Non-Compete",                          # 9
    "Exclusivity",                          # 10
    "No-Solicit of Customers",              # 11
    "No-Solicit of Employees",              # 12
    "Non-Disparagement",                    # 13
    "Termination for Convenience",          # 14
    "ROFR/ROFO/ROFN",                       # 15
    "Change of Control",                    # 16
    "Anti-Assignment",                      # 17
    "Revenue/Profit Sharing",               # 18
    "Price Restriction",                    # 19
    "Minimum Commitment",                   # 20
    "Volume Restriction",                   # 21
    "IP Ownership Assignment",              # 22
    "Joint IP Ownership",                   # 23
    "License Grant",                        # 24
    "Non-Transferable License",             # 25
    "Affiliate License-Licensor",           # 26
    "Affiliate License-Licensee",           # 27
    "Unlimited/All-You-Can-Eat License",    # 28
    "Irrevocable or Perpetual License",     # 29
    "Source Code Escrow",                   # 30
    "Post-Termination Services",            # 31
    "Audit Rights",                         # 32
    "Uncapped Liability",                   # 33
    "Cap on Liability",                     # 34
    "Liquidated Damages",                   # 35
    "Warranty Duration",                    # 36
    "Insurance",                            # 37
    "Covenant Not to Sue",                  # 38
    "Third Party Beneficiary",              # 39
    "Other",                                # 40
]
# Consumer-terms ("unfair") categories. They are used as regex-fallback labels
# (see _REGEX_PATTERNS) and are not outputs of the CUAD classifier.
_UNFAIR_LABELS = [
    "Limitation of liability", "Unilateral termination", "Unilateral change",
    "Content removal", "Contract by using", "Choice of law",
    "Jurisdiction", "Arbitration"
]
# FIX v4.2: Include regex-only labels that aren't in CUAD or Unfair lists
_EXTRA_REGEX_LABELS = [
    "Indemnification", "Confidentiality", "Force Majeure", "Penalties"
]
# Union of every known label; DESC_MAP is seeded from this list.
_ALL_LABELS = CUAD_LABELS + _UNFAIR_LABELS + _EXTRA_REGEX_LABELS
# Risk tier for each label. Lookups use RISK_MAP.get(label, "LOW"), so any
# label missing here is treated as LOW risk.
RISK_MAP = {
    # Critical
    "Uncapped Liability": "CRITICAL",
    "Arbitration": "CRITICAL",
    "IP Ownership Assignment": "CRITICAL",
    "Termination for Convenience": "CRITICAL",
    "Limitation of liability": "CRITICAL",
    "Unilateral termination": "CRITICAL",
    "Liquidated Damages": "CRITICAL",
    # High
    "Non-Compete": "HIGH",
    "Exclusivity": "HIGH",
    "Change of Control": "HIGH",
    "No-Solicit of Customers": "HIGH",
    "No-Solicit of Employees": "HIGH",
    "Unilateral change": "HIGH",
    "Content removal": "HIGH",
    "Anti-Assignment": "HIGH",
    "Notice Period to Terminate Renewal": "HIGH",
    # Medium
    "Governing Law": "MEDIUM",
    "Jurisdiction": "MEDIUM",
    "Choice of law": "MEDIUM",
    "Price Restriction": "MEDIUM",
    "Minimum Commitment": "MEDIUM",
    "Volume Restriction": "MEDIUM",
    "Non-Disparagement": "MEDIUM",
    "Most Favored Nation": "MEDIUM",
    "Revenue/Profit Sharing": "MEDIUM",
    "Warranty Duration": "MEDIUM",
    # Low
    "Document Name": "LOW",
    "Parties": "LOW",
    "Agreement Date": "LOW",
    "Effective Date": "LOW",
    "Expiration Date": "LOW",
    "Renewal Term": "LOW",
    "Joint IP Ownership": "LOW",
    "License Grant": "LOW",
    "Non-Transferable License": "LOW",
    "Affiliate License-Licensor": "LOW",
    "Affiliate License-Licensee": "LOW",
    "Unlimited/All-You-Can-Eat License": "LOW",
    "Irrevocable or Perpetual License": "LOW",
    "Source Code Escrow": "LOW",
    "Post-Termination Services": "LOW",
    "Audit Rights": "LOW",
    "Cap on Liability": "LOW",
    "Insurance": "LOW",
    "Covenant Not to Sue": "LOW",
    "Third Party Beneficiary": "LOW",
    "Other": "LOW",
    "ROFR/ROFO/ROFN": "LOW",
    "Contract by using": "LOW",
    # FIX v4.2: Added regex-only labels that were missing from RISK_MAP
    "Indemnification": "HIGH",
    "Confidentiality": "MEDIUM",
    "Force Majeure": "LOW",
    "Penalties": "HIGH",
}
# Human-readable description per label, shown next to each detected clause.
# Every known label gets an entry (seeded with the label text itself), and the
# curated descriptions below override that default.
DESC_MAP = {label: label.replace("_", " ") for label in _ALL_LABELS}
DESC_MAP.update({
    "Limitation of liability": "Company limits or excludes liability for losses, data breaches, or service failures.",
    "Unilateral termination": "Company can terminate your account at any time without reason.",
    "Unilateral change": "Company can change terms at any time without your consent.",
    "Content removal": "Company can delete your content without notice or justification.",
    "Contract by using": "You are bound to the contract simply by using the service.",
    "Choice of law": "Governing law may differ from your country, reducing your legal protections.",
    "Jurisdiction": "Disputes must be resolved in a jurisdiction that may disadvantage you.",
    "Arbitration": "Forces disputes to arbitration instead of court. You waive your right to sue.",
    "Uncapped Liability": "No financial limit on damages the party may be liable for.",
    "Cap on Liability": "Maximum financial liability is explicitly capped.",
    "Non-Compete": "Restrictions on competing with the counter-party.",
    "Exclusivity": "Obligation to deal exclusively with one party.",
    "IP Ownership Assignment": "Intellectual property rights are transferred entirely.",
    # FIX: termination for convenience means "without cause", usually by
    # giving notice; the old text wrongly said "without cause or notice".
    "Termination for Convenience": "A party may terminate without cause, typically just by giving notice.",
    "Governing Law": "Specifies which jurisdiction's laws apply.",
    "Non-Disparagement": "Agreement not to speak negatively about the other party.",
    "ROFR/ROFO/ROFN": "Right of First Refusal / Offer / Negotiation clause.",
    "Change of Control": "Provisions triggered by ownership or control changes.",
    "Anti-Assignment": "Restrictions on transferring contract rights to third parties.",
    "Liquidated Damages": "Pre-determined damages amount for breach of contract.",
    "Source Code Escrow": "Third-party holds source code for release under defined conditions.",
    "Post-Termination Services": "Services to be provided after the contract ends.",
    "Audit Rights": "Right to inspect records or verify compliance.",
    "Warranty Duration": "Length of time warranties remain in effect.",
    "Covenant Not to Sue": "Agreement not to bring legal action against a party.",
    "Third Party Beneficiary": "Non-party who benefits from the contract terms.",
    "Insurance": "Insurance coverage requirements.",
    "Revenue/Profit Sharing": "Revenue or profit sharing arrangements between parties.",
    "Price Restriction": "Restrictions on pricing or discounting.",
    "Minimum Commitment": "Minimum purchase or usage commitment.",
    "Volume Restriction": "Limits on volume of goods or services.",
    "License Grant": "Permission to use intellectual property.",
    "Non-Transferable License": "License that cannot be transferred to third parties.",
    "Irrevocable or Perpetual License": "License that cannot be revoked or lasts indefinitely.",
    "Unlimited/All-You-Can-Eat License": "License with no usage limits.",
    "Notice Period to Terminate Renewal": "Required notice period before automatic renewal.",
    # Curated descriptions for CUAD labels that previously fell back to the
    # bare label text.
    "Document Name": "The name or title of the contract.",
    "Parties": "The parties who signed or are bound by the contract.",
    "Agreement Date": "The date the contract was signed or entered into.",
    "Effective Date": "The date the contract takes effect.",
    "Expiration Date": "The date the initial contract term ends.",
    "Renewal Term": "Renewal or extension term after the initial term expires.",
    "Most Favored Nation": "Buyer is entitled to any better terms later given to a third party.",
    "No-Solicit of Customers": "Restrictions on soliciting the counter-party's customers.",
    "No-Solicit of Employees": "Restrictions on soliciting or hiring the counter-party's employees.",
    "Joint IP Ownership": "Intellectual property is jointly owned by the parties.",
    "Affiliate License-Licensor": "License granted by, or covering IP of, the licensor's affiliates.",
    "Affiliate License-Licensee": "License that extends to the licensee's affiliates.",
    "Other": "Clause that does not match a specific tracked category.",
    # FIX v4.2: Added descriptions for regex-only labels
    "Indemnification": "Obligation to compensate the other party for losses or damages.",
    "Confidentiality": "Restrictions on sharing proprietary or sensitive information.",
    "Force Majeure": "Excuses performance due to extraordinary events beyond control.",
    "Penalties": "Financial penalties for breach or late performance.",
})
# Per-tier weights. Presumably these are combined into the document risk score
# by code outside this section; confirm against the scoring code.
RISK_WEIGHTS = {"CRITICAL": 40, "HIGH": 20, "MEDIUM": 10, "LOW": 3}
# Per-tier display style: presumably (accent hex colour, light background hex
# colour, icon) for HTML badges.
# NOTE(review): the icon strings look mis-encoded (a UTF-8 emoji decoded with
# the wrong codec). Confirm against the original file before shipping.
RISK_STYLES = {
    "CRITICAL": ("#dc2626", "#fef2f2", "β οΈ"),
    "HIGH": ("#ea580c", "#fff7ed", "β‘"),
    "MEDIUM": ("#ca8a04", "#fefce8", "π"),
    "LOW": ("#16a34a", "#f0fdf4", "β"),
}
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # FIX v4.1: Per-class thresholds aligned with single-label softmax | |
| # The model was trained with cross-entropy (single-label), so inference | |
| # now uses softmax+argmax, not sigmoid. Thresholds apply to softmax probs. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _CUAD_THRESHOLDS = {} | |
| _WEAK_CLASSES = {0, 1, 2, 7, 9, 21, 22, 27, 37, 38} | |
| for _i in range(41): | |
| if _i in _WEAK_CLASSES: | |
| _CUAD_THRESHOLDS[_i] = 0.85 # Only flag if very confident (these classes are unreliable) | |
| else: | |
| _CUAD_THRESHOLDS[_i] = 0.40 # Reasonable threshold for softmax outputs | |
# =====================================================================
# FIX v4.1: Bounded LRU Cache utility (replaces unbounded dicts)
# =====================================================================
class BoundedCache:
    """Thread-safe bounded LRU cache using OrderedDict + RLock.

    Entries are kept in recency order. ``get`` and ``put`` on an existing key
    mark it most-recently-used. Inserting a new key into a full cache evicts
    the least-recently-used entry.

    FIX v4.2: Every operation holds a ``threading.RLock`` because Gradio
    handles requests concurrently. The OrderedDict compound operations used
    here (contains + setitem + move_to_end + popitem) are NOT atomic as a
    group, even with the GIL.

    Args:
        maxsize: Maximum number of entries. A value <= 0 disables caching
            (``put`` stores nothing), like ``functools.lru_cache(maxsize=0)``.
            FIX: previously ``put`` raised KeyError from ``popitem()`` on the
            empty dict in that case.
    """

    def __init__(self, maxsize=1000):
        self._cache = OrderedDict()
        self._maxsize = maxsize
        self._lock = threading.RLock()

    def get(self, key, default=None):
        """Return the value for *key* and mark it most recently used, else *default*."""
        with self._lock:
            try:
                self._cache.move_to_end(key)
            except KeyError:
                return default
            return self._cache[key]

    def put(self, key, value):
        """Insert or update *key*, evicting least-recently-used entries beyond maxsize."""
        if self._maxsize <= 0:
            return  # caching disabled
        with self._lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            while len(self._cache) > self._maxsize:
                self._cache.popitem(last=False)

    def __contains__(self, key):
        with self._lock:
            return key in self._cache

    def __len__(self):
        with self._lock:
            return len(self._cache)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 2. MODEL LOADING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| cuad_tokenizer = None | |
| cuad_model = None | |
| ner_pipeline = None | |
| nli_model = None # FIX v4.2: CrossEncoder instead of pipeline | |
| _model_status = {"cuad": "not_loaded", "ner": "not_loaded", "nli": "not_loaded"} | |
def _load_cuad_model():
    """Load the CUAD clause classifier into ``cuad_model`` / ``cuad_tokenizer``.

    Load order:
      1. PERF v4.3: ONNX INT8 model (only when optimum.onnxruntime imported),
         tried from $ONNX_MODEL_PATH and then $ONNX_HUB_MODEL_ID
         (default "gaurv007/clauseguard-onnx-int8").
      2. Fallback: PyTorch legal-bert base + PEFT/LoRA adapter.
    If nothing loads, both globals stay None and the classifiers use regex.
    The outcome is recorded in ``_model_status["cuad"]``.

    FIX: the globals are assigned only after a model AND its tokenizer have
    both loaded. Previously a half-successful ONNX attempt (model loaded,
    tokenizer failed) left a stale ``cuad_model`` behind.
    """
    global cuad_tokenizer, cuad_model

    onnx_model_path = os.environ.get("ONNX_MODEL_PATH", "")
    onnx_hub_id = os.environ.get("ONNX_HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
    if _HAS_ORT:
        for source in (onnx_model_path, onnx_hub_id):
            if not source:
                continue
            try:
                print(f"[ClauseGuard] Trying ONNX model: {source}")
                model = _ORTModel.from_pretrained(source, file_name="model_quantized.onnx")
                tokenizer = AutoTokenizer.from_pretrained(source)
            except Exception as e:
                print(f"[ClauseGuard] ONNX load failed from {source}: {e}")
                continue
            cuad_model, cuad_tokenizer = model, tokenizer
            _model_status["cuad"] = "loaded (ONNX INT8)"
            print(f"[ClauseGuard] ONNX INT8 model loaded from {source}")
            return

    # Fallback to PyTorch PEFT model
    if not _HAS_TORCH:
        print("[ClauseGuard] PyTorch not available β using regex fallback")
        _model_status["cuad"] = "unavailable"
        return
    try:
        base = "nlpaueb/legal-bert-base-uncased"
        adapter = "Mokshith31/legalbert-contract-clause-classification"
        print(f"[ClauseGuard] Loading CUAD classifier (PyTorch): {adapter}")
        tokenizer = AutoTokenizer.from_pretrained(base)
        # The base head is resized to 41 CUAD classes; the adapter supplies
        # the fine-tuned weights on top.
        base_model = AutoModelForSequenceClassification.from_pretrained(
            base, num_labels=41, ignore_mismatched_sizes=True
        )
        model = PeftModel.from_pretrained(base_model, adapter)
        model.eval()
    except Exception as e:
        print(f"[ClauseGuard] CUAD model load failed: {e}")
        cuad_tokenizer = None
        cuad_model = None
        _model_status["cuad"] = f"failed: {e}"
        return
    cuad_model, cuad_tokenizer = model, tokenizer
    _model_status["cuad"] = "loaded (PyTorch)"
    print("[ClauseGuard] CUAD model loaded successfully (PyTorch)")
def _load_ner_model():
    """Load the matterstack/legal-bert-ner pipeline into ``ner_pipeline``.

    Runs on CPU (``device=-1``) with ``aggregation_strategy="simple"``.
    - Success: sets ``_HAS_NER_MODEL`` and marks ``_model_status["ner"]`` as
      "loaded".
    - Failure: records "failed: <error>" and leaves ``ner_pipeline`` unchanged.
    - Skipped with status "unavailable" when PyTorch/transformers did not import.
    """
    global ner_pipeline, _HAS_NER_MODEL
    if not _HAS_TORCH:
        _model_status["ner"] = "unavailable"
        return
    try:
        print("[ClauseGuard] Loading Legal NER model: matterstack/legal-bert-ner")
        ner = pipeline(
            "ner",
            model="matterstack/legal-bert-ner",
            aggregation_strategy="simple",
            device=-1,  # CPU
        )
    except Exception as e:
        print(f"[ClauseGuard] Legal NER model load failed (using regex fallback): {e}")
        _model_status["ner"] = f"failed: {e}"
        return
    ner_pipeline = ner
    _HAS_NER_MODEL = True
    _model_status["ner"] = "loaded"
    print("[ClauseGuard] Legal NER model loaded successfully")
def _load_nli_model():
    """Load the cross-encoder/nli-deberta-v3-base CrossEncoder into ``nli_model``.

    - Success: sets ``_HAS_NLI_MODEL`` and marks ``_model_status["nli"]`` as
      "loaded".
    - Failure: records "failed: <error>".
    - Skipped with an "unavailable" status when sentence-transformers did not
      import.
    """
    global nli_model, _HAS_NLI_MODEL
    if not _HAS_CROSS_ENCODER:
        _model_status["nli"] = "unavailable (sentence-transformers not installed)"
        return
    try:
        print("[ClauseGuard] Loading NLI model: cross-encoder/nli-deberta-v3-base (CrossEncoder)")
        encoder = _CrossEncoder("cross-encoder/nli-deberta-v3-base")
    except Exception as e:
        print(f"[ClauseGuard] NLI model load failed (using heuristic fallback): {e}")
        _model_status["nli"] = f"failed: {e}"
        return
    nli_model = encoder
    _HAS_NLI_MODEL = True
    _model_status["nli"] = "loaded"
    print("[ClauseGuard] NLI CrossEncoder loaded successfully")
def get_model_status_text(status=None):
    """Return a one-line, human-readable summary of model load status.

    Args:
        status: Optional mapping of model key -> status string. Defaults to the
            module-level ``_model_status``.

    Returns:
        One "<icon> <label>: <status>" entry per model, joined with a separator.
        Unknown keys are shown under their raw key instead of raising KeyError.

    FIX: a status counts as healthy when it *starts with* "loaded".
    _load_cuad_model reports "loaded (ONNX INT8)" / "loaded (PyTorch)", which
    the old exact ``== "loaded"`` test rendered with the not-loaded icon.
    """
    if status is None:
        status = _model_status
    names = {"cuad": "Clause Classifier", "ner": "Legal NER", "nli": "NLI Contradiction"}
    parts = []
    for name, state in status.items():
        if state.startswith("loaded"):
            icon = "β "
        elif "failed" in state:
            icon = "β οΈ"
        else:
            icon = "β"
        parts.append(f"{icon} {names.get(name, name)}: {state}")
    return " Β· ".join(parts)
# Load every model once at import time. Each loader records its own outcome
# in _model_status.
for _loader in (_load_cuad_model, _load_ner_model, _load_nli_model):
    _loader()
del _loader
# =====================================================================
# 3. DOCUMENT PARSING
# =====================================================================
def parse_pdf(file_path):
    """Extract text from a PDF, using OCR for scanned documents.

    Delegates to ``parse_pdf_smart``. Returns ``(text, None)`` on success, or
    ``(None, message)`` using the parser's error (or a generic hint when it
    gave none).
    """
    text, error, method = parse_pdf_smart(file_path)
    if not text:
        return None, error or "Could not extract text from PDF. Try uploading a clearer scan or digital PDF."
    if method == "ocr":
        print(f"[ClauseGuard] PDF extracted via OCR ({len(text)} chars)")
    return text, None
def parse_docx(file_path):
    """Read a .docx file and return ``(text, None)`` or ``(None, error)``.

    Non-blank paragraphs are joined with blank lines ("\\n\\n").
    """
    if not _HAS_DOCX:
        return None, "DOCX parsing not available (python-docx not installed)"
    try:
        body = "\n\n".join(
            para.text
            for para in DocxDocument(file_path).paragraphs
            if para.text.strip()
        )
    except Exception as e:
        return None, f"DOCX parse error: {e}"
    return body, None
def parse_document(file_path):
    """Extract plain text from an uploaded contract file.

    Dispatches on the case-insensitive file extension:
      .pdf               -> parse_pdf (native text with OCR fallback)
      .docx              -> parse_docx
      .doc               -> rejected with a clear message (see below)
      .txt / .md / .rst  -> read as UTF-8, undecodable bytes dropped

    Returns:
        ``(text, None)`` on success, otherwise ``(None, error_message)``.
    """
    if not file_path:  # None or "" (nothing uploaded)
        return None, "No file uploaded"
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".pdf":
        return parse_pdf(file_path)
    if ext == ".docx":
        return parse_docx(file_path)
    if ext == ".doc":
        # FIX: python-docx only reads the Office Open XML (.docx) format. A
        # legacy binary .doc used to surface as a cryptic "DOCX parse error".
        return None, "Legacy .doc files are not supported. Please save the document as .docx or PDF and re-upload."
    if ext in (".txt", ".md", ".rst"):
        try:
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read(), None
        except Exception as e:
            return None, f"Text read error: {e}"
    return None, f"Unsupported file type: {ext}"
# =====================================================================
# 4. DETERMINISTIC CLAUSE SPLITTING
# FIX v4.1: Bounded cache (max 500 documents) instead of unbounded dict
# =====================================================================
# split_clauses() results keyed by a SHA-256 hex digest computed in
# split_clauses; LRU-bounded to 500 documents.
_chunk_cache = BoundedCache(maxsize=500)
# FIX v4.2: Pre-compile section pattern at module level (was recompiling per call)
# Each match's start() is a candidate section boundary. A candidate is a line
# start (``^`` under MULTILINE) or a blank line (``\n\n``) that is followed
# (lookahead, not consumed) by one of the section-marker forms below.
# NOTE(review): since ``^`` matches after every newline, a marker preceded by a
# blank line can be reported twice: once at the ``\n\n`` and once at the next
# line start. The gap between the two is whitespace only.
_SECTION_PATTERN = re.compile(
    r'(?:^|\n\n)'
    r'(?='
    r'\d+(?:\.\d+)*[.)]\s'                 # 1. 2. 3.1. 3.1)
    r'|[A-Z]{2,}[A-Z\s]*\n'                # ALL CAPS HEADERS
    r'|\([a-z]\)\s'                        # (a) (b) (c)
    r'|(?:Section|Article|Clause)\s+\d+'   # Section 1, Article 2
    r')',
    re.MULTILINE
)
def _section_starts(text):
    """Return one start offset per section marker found by ``_SECTION_PATTERN``.

    The pattern can report the same marker twice when it follows a blank line
    (at the ``\\n\\n`` and again at the next line start). A match separated
    from the previous one only by whitespace is such a duplicate and is
    skipped, so it no longer inflates the ">= 3 sections" test. Clause text is
    unaffected because the skipped span is whitespace only.
    """
    starts = []
    for match in _SECTION_PATTERN.finditer(text):
        pos = match.start()
        if starts and not text[starts[-1]:pos].strip():
            continue
        starts.append(pos)
    return starts


def split_clauses(text):
    """Deterministic, structure-aware clause splitting.

    Strategy:
      1. If at least 3 distinct section markers are found (numbered headings,
         ALL-CAPS headers, "(a)" items, "Section/Article/Clause N"), split on
         them. Sections over 1500 chars are re-packed on blank lines into
         pieces of under ~1200 chars. Text before the first marker becomes a
         preamble clause when the marker starts after offset 50.
      2. Otherwise (or if step 1 yields nothing) use ``_fallback_split``.
    Fragments of 30 characters or fewer are dropped.

    The same input always produces the same output, and results are cached in
    ``_chunk_cache``.

    FIX: the cache key is now the exact text that gets split (trimmed, with
    3+ newlines collapsed). It used to be a whitespace-collapsed copy, so
    documents differing only in line breaks shared a key: the second got the
    first one's clauses, even though splitting depends on those line breaks.
    A fresh list is returned each call so callers cannot mutate the cached
    entry.

    Returns:
        list[str]: clause texts in document order.
    """
    text = re.sub(r'\n{3,}', '\n\n', text.strip())
    text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()
    cached = _chunk_cache.get(text_hash)
    if cached is not None:
        return list(cached)

    positions = _section_starts(text)
    if len(positions) < 3:
        result = _fallback_split(text)
        _chunk_cache.put(text_hash, result)
        return list(result)

    clauses = []
    for start, end in zip(positions, positions[1:] + [len(text)]):
        chunk = text[start:end].strip()
        if len(chunk) <= 30:
            continue
        if len(chunk) <= 1500:
            clauses.append(chunk)
            continue
        # Over-long section: re-pack its paragraphs into pieces of < ~1200 chars.
        current = ""
        for part in chunk.split('\n\n'):
            if len(current) + len(part) < 1200:
                current = current + "\n\n" + part if current else part
            else:
                if len(current.strip()) > 30:
                    clauses.append(current.strip())
                current = part
        if len(current.strip()) > 30:
            clauses.append(current.strip())

    if positions[0] > 50:
        preamble = text[:positions[0]].strip()
        if len(preamble) > 30:
            clauses.insert(0, preamble)

    result = clauses if clauses else _fallback_split(text)
    _chunk_cache.put(text_hash, result)
    return list(result)
| def _fallback_split(text): | |
| """Fallback: split on paragraph breaks and sentence boundaries.""" | |
| paragraphs = text.split('\n\n') | |
| if len(paragraphs) >= 3: | |
| clauses = [] | |
| for p in paragraphs: | |
| p = p.strip() | |
| if len(p) > 30: | |
| if len(p) > 1500: | |
| sents = re.split(r'(?<=[.!?])\s+(?=[A-Z])', p) | |
| current = "" | |
| for s in sents: | |
| if len(current) + len(s) < 1000: | |
| current += (" " + s if current else s) | |
| else: | |
| if len(current.strip()) > 30: | |
| clauses.append(current.strip()) | |
| current = s | |
| if len(current.strip()) > 30: | |
| clauses.append(current.strip()) | |
| else: | |
| clauses.append(p) | |
| return clauses | |
| parts = re.split(r'(?<=[.!?])\s+(?=[A-Z0-9(])', text) | |
| return [p.strip() for p in parts if len(p.strip()) > 30] | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 5. CLAUSE DETECTION | |
| # FIX v4.1: Use softmax (matching training) instead of sigmoid | |
| # FIX v4.1: max_length raised to 512 (was 256) | |
| # FIX v4.1: Bounded prediction cache | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _HEADING_RE = re.compile(r'^\d+(?:\.\d+)*\s+[A-Z][A-Z\s&,/]+$', re.MULTILINE) | |
| def _strip_heading(text): | |
| """Remove leading section headings that confuse the classifier.""" | |
| lines = text.split('\n') | |
| if lines and _HEADING_RE.match(lines[0].strip()): | |
| stripped = '\n'.join(lines[1:]).strip() | |
| return stripped if len(stripped) > 20 else text | |
| return text | |
| _LABEL_GUARDRAILS = { | |
| "Liquidated Damages": re.compile( | |
| r'liquidated|pre-?determined.{0,10}damage|agreed.{0,10}sum|penalty clause|stipulated.{0,10}damage', | |
| re.IGNORECASE | |
| ), | |
| "Uncapped Liability": re.compile( | |
| r'uncapped|unlimited.{0,10}liabilit|no.{0,10}(limit|cap).{0,10}liabilit', | |
| re.IGNORECASE | |
| ), | |
| } | |
| def _apply_guardrails(label, text, confidence): | |
| guard = _LABEL_GUARDRAILS.get(label) | |
| if guard and not guard.search(text): | |
| return "Other", confidence * 0.3 | |
| return label, confidence | |
| def _text_hash(text): | |
| return hashlib.md5(text.encode()).hexdigest() | |
# FIX v4.1: Bounded prediction cache
# Classifier results per clause (keys come from _text_hash), shared by
# classify_cuad and classify_cuad_batch. Holds at most 2000 entries with LRU
# eviction.
_prediction_cache = BoundedCache(maxsize=2000)
def classify_cuad(clause_text):
    """Classify a single clause into CUAD categories.

    Thin wrapper around ``classify_cuad_batch`` with a batch of one, so the
    single-clause and batched paths share one implementation: the same
    prediction cache, per-class thresholds, keyword guardrails, top-2 logic
    and regex fallback. This replaces a ~70-line hand-maintained copy of the
    batch logic.

    Args:
        clause_text: Raw clause text. A leading heading is stripped before
            inference; guardrails check the original text.

    Returns:
        list[dict]: up to two ML results (``label``, ``confidence``, ``risk``,
        ``description``, ``source``) sorted by confidence. When the model is
        unavailable, predicts nothing usable, or fails, the ``_classify_regex``
        result is returned instead.
    """
    return classify_cuad_batch([clause_text])[0]
# =====================================================================
# 5b. BATCHED CLAUSE CLASSIFICATION
# PERF v4.3: One forward pass per batch of clauses instead of one-by-one
# =====================================================================
def _cuad_result_for_class(class_idx, conf, original_text):
    """Build the result dict for one predicted class, or None if it is filtered out.

    The prediction must clear its per-class threshold, pass the keyword
    guardrails (unconfirmed labels become a low-confidence "Other"), and not
    end up as an "Other" below 0.3 confidence.
    """
    if class_idx >= len(CUAD_LABELS) or conf <= _CUAD_THRESHOLDS.get(class_idx, 0.40):
        return None
    label, conf = _apply_guardrails(CUAD_LABELS[class_idx], original_text, conf)
    if label == "Other" and conf < 0.3:
        return None
    return {
        "label": label,
        "confidence": round(conf, 3),
        "risk": RISK_MAP.get(label, "LOW"),
        "description": DESC_MAP.get(label, label),
        "source": "ml",
    }


def _cuad_results_from_probs(clause_probs, original_text):
    """Turn one clause's softmax vector into its result list (top-2, regex fallback)."""
    k = min(2, int(clause_probs.shape[-1]))
    top_probs, top_indices = torch.topk(clause_probs, k)
    results = []
    primary = _cuad_result_for_class(int(top_indices[0]), float(top_probs[0]), original_text)
    if primary is not None:
        results.append(primary)
    if k > 1:
        runner_up = _cuad_result_for_class(int(top_indices[1]), float(top_probs[1]), original_text)
        # Keep the runner-up only if its (possibly guardrailed) label differs
        # from the primary's.
        if runner_up is not None and (not results or results[0]["label"] != runner_up["label"]):
            results.append(runner_up)
    results.sort(key=lambda r: r["confidence"], reverse=True)
    if not results:
        # Nothing confident from the model: let the regex patterns have a go.
        results = _classify_regex(original_text)
    return results


def classify_cuad_batch(clauses, batch_size=8):
    """Classify many clauses, running the model on batches of uncached clauses.

    PERF v4.3: replaces a sequential classify_cuad() loop. On CPU,
    batch_size=8 balances memory against throughput.

    Args:
        clauses: Clause texts.
        batch_size: Maximum clauses per forward pass (values < 1 act as 1).

    Returns:
        list: one result list per clause, in input order. Regex results are
        used when the model is unavailable or a batch fails.

    Fixes:
      * The cache key is the MD5 of the full clause text. It used to be the
        first 512 characters of the heading-stripped text, so clauses sharing
        a 512-char prefix got each other's cached results, even though the
        tokenizer reads up to 512 *tokens* and the guardrails read the whole
        clause.
      * Results are stored by index instead of a linear search per clause
        (previously O(n^2) in the number of clauses).
      * If a batch fails part-way, clauses already scored keep their ML result
        instead of being overwritten by the regex fallback.
    """
    if cuad_model is None or cuad_tokenizer is None:
        return [_classify_regex(c) for c in clauses]
    batch_size = max(1, int(batch_size))

    results = [None] * len(clauses)
    pending = []  # (index, heading-stripped text, cache key) for cache misses
    for i, clause in enumerate(clauses):
        key = _text_hash(clause)
        cached = _prediction_cache.get(key)
        if cached is not None:
            results[i] = cached
        else:
            pending.append((i, _strip_heading(clause), key))

    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        try:
            inputs = cuad_tokenizer(
                [clean for _, clean, _ in batch],
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            )
            with torch.no_grad():
                # Single-label model trained with cross-entropy -> softmax.
                probs = torch.softmax(cuad_model(**inputs).logits, dim=-1)
            for row, (i, _, key) in enumerate(batch):
                clause_results = _cuad_results_from_probs(probs[row], clauses[i])
                _prediction_cache.put(key, clause_results)
                results[i] = clause_results
        except Exception as e:
            print(f"[ClauseGuard] Batch CUAD inference error: {e}")
            for i, _, _ in batch:
                if results[i] is None:  # keep rows scored before the failure
                    results[i] = _classify_regex(clauses[i])
    return results
# FIX v4.1: Extended regex patterns to cover more CUAD categories.
# Each label maps to alternative patterns; _classify_regex reports the label
# when ANY of them matches (patterns are compiled with re.IGNORECASE), in this
# dict's insertion order.
# Short tokens ("ip", "cap", "no") are anchored with \b so they do not fire
# inside ordinary words ("equipment", "capacity", "notwithstanding").
_REGEX_PATTERNS = {
    "Limitation of liability": [r"not liable", r"shall not be (liable|responsible)", r"in no event.*liable", r"limitation of liability", r"without warranty", r"disclaim"],
    "Unilateral termination": [r"terminat.*at any time", r"suspend.*account.*without", r"we may (terminat|suspend|discontinu)", r"right to (terminat|suspend)"],
    "Unilateral change": [r"sole discretion", r"reserves? the right to (modify|change|update|amend)", r"at any time.*without (prior )?notice", r"we may (modify|change|update)"],
    "Content removal": [r"remove.*content.*without", r"right to remove", r"we may.*remove"],
    "Contract by using": [r"by (using|accessing).*you agree", r"continued use.*constitutes? acceptance"],
    "Choice of law": [r"governed by.*laws? of", r"shall be governed", r"laws of the state of"],
    "Jurisdiction": [r"exclusive jurisdiction", r"courts? of.*(california|delaware|new york|ireland|england)", r"submit to.*jurisdiction"],
    "Arbitration": [r"arbitrat", r"binding arbitration", r"waive.*right.*court", r"class action waiver"],
    "Governing Law": [r"governed by", r"laws of", r"jurisdiction of"],
    "Termination for Convenience": [r"terminat.*for convenience", r"terminat.*without cause", r"terminat.*at any time"],
    "Non-Compete": [r"non-compete", r"shall not compete", r"competition restriction"],
    "Exclusivity": [r"exclusive(?:ly)?(?:\s+(?:deal|relationship|partner|right))", r"exclusivity"],
    # "\bip\b": a bare "ip" also occurs inside "equipment", "relationship", ...
    "IP Ownership Assignment": [r"assign.*intellectual property", r"ownership of.*\bip\b", r"all rights.*assign", r"work.?for.?hire"],
    # The old catch-all r"no.*limit.*liability" matched any "no" substring
    # (e.g. "notwithstanding ... limitation of liability") and so tagged
    # ordinary capped clauses as uncapped.
    "Uncapped Liability": [
        r"unlimited liability",
        r"uncapped",
        r"\bno\s+(?:limit|cap)\w*\s+(?:on|of|to)\s+(?:[\w']+\s+){0,3}liabilit",
        r"liabilit\w*\s+(?:shall|will)\s+(?:not\s+be\s+limited|be\s+unlimited)",
    ],
    "Cap on Liability": [r"cap on liability", r"maximum liability", r"liability.*shall not exceed", r"aggregate liability.*not exceed"],
    "Indemnification": [r"indemnif", r"hold harmless", r"defend.*against.*claim"],
    "Confidentiality": [r"confidential(?:ity)?", r"non-disclosure", r"\bnda\b"],
    "Force Majeure": [r"force majeure", r"act of god", r"beyond.*(?:reasonable\s+)?control"],
    "Penalties": [r"penalt(?:y|ies)", r"late fee", r"default charge", r"interest on overdue"],
    # FIX v4.1: Added missing regex patterns for more CUAD categories
    "Audit Rights": [r"audit rights?", r"right to audit", r"inspect.*records?", r"examination of.*records?", r"access to.*books"],
    "Warranty Duration": [r"warrant(?:y|ies).*(?:period|duration|term|months?|years?)", r"warranty.*shall.*(?:remain|last|continue)", r"limited warranty"],
    "Insurance": [r"(?:shall|must).*maintain.*insurance", r"insurance.*coverage", r"policy of insurance", r"certificate of insurance"],
    "Source Code Escrow": [r"source code escrow", r"escrow.*source code", r"escrow agent"],
    "Post-Termination Services": [r"post.?termination.*(?:service|obligation|support)", r"(?:after|following|upon).*termination.*(?:shall|must|will).*(?:provide|continue)"],
    "Renewal Term": [r"renew(?:al)?.*term", r"auto(?:matic(?:ally)?)?.*renew", r"successive.*(?:term|period)"],
    "Notice Period to Terminate Renewal": [r"notice.*(?:to\s+)?terminat.*renew", r"(?:days?|months?).*(?:prior|advance).*(?:notice|written).*(?:terminat|renew)", r"notice of non.?renewal"],
    "Change of Control": [r"change of control", r"change in.*(?:ownership|control)", r"merger.*acquisition", r"sale of.*(?:all|substantially).*assets"],
    "Anti-Assignment": [r"(?:shall|may)\s+not\s+assign", r"anti.?assignment", r"no.*assignment.*without.*consent"],
    "Revenue/Profit Sharing": [r"revenue.*shar", r"profit.*shar", r"royalt(?:y|ies)"],
    "Liquidated Damages": [r"liquidated.*damages?", r"pre.?determined.*damage", r"stipulated.*damage"],
    "Covenant Not to Sue": [r"covenant not to sue", r"(?:shall|agree).*not.*(?:bring|file|commence).*(?:action|claim|suit)"],
    "Joint IP Ownership": [r"joint(?:ly)?.*own(?:ed|ership)?.*(?:\bip\b|intellectual property)", r"co.?own(?:ed|ership)?"],
    "License Grant": [r"(?:grant|license).*(?:non.?exclusive|exclusive|perpetual|irrevocable).*(?:license|right)", r"hereby grants?.*license"],
    "Non-Transferable License": [r"non.?transferable.*license", r"license.*(?:shall|may)\s+not.*(?:transfer|assign|sublicense)"],
    "ROFR/ROFO/ROFN": [r"right of first.*(?:refusal|offer|negotiation)", r"ROFR", r"ROFO", r"ROFN"],
    "No-Solicit of Customers": [r"(?:shall|must|agree).*not.*solicit.*customer", r"\bno.?solicit.*customer", r"non.?solicitation.*customer"],
    "No-Solicit of Employees": [r"(?:shall|must|agree).*not.*solicit.*employee", r"\bno.?solicit.*employee", r"non.?solicitation.*employee", r"\bno.?hire"],
    "Non-Disparagement": [r"non.?disparagement", r"(?:shall|must|agree).*not.*(?:disparag|defam|make.*negative)", r"not.*make.*derogatory"],
    "Most Favored Nation": [r"most favou?red.*nation", r"MFN", r"most favou?red.*(?:customer|pricing|terms)"],
    "Third Party Beneficiary": [r"third.?party.*beneficiar", r"no.*third.?party.*beneficiar"],
    "Minimum Commitment": [r"minimum.*(?:commitment|purchase|order|volume|spend)", r"(?:shall|must).*(?:purchase|order).*(?:at least|minimum|no less than)"],
    # "\bcap\b": a bare "cap" also occurs inside "capacity", "capital", "escape".
    "Volume Restriction": [r"volume.*(?:restriction|limitation|\bcap\b|ceiling)", r"(?:shall|may).*not.*exceed.*(?:volume|quantity)"],
    "Price Restriction": [r"price.*(?:restriction|limitation|ceiling|\bcap\b|floor)", r"(?:shall|may).*not.*(?:increase|raise|exceed).*price"],
}
# FIX v4.2: Pre-compile regex patterns at module level (was recompiling per call).
# Same keys, same insertion order and same pattern order as _REGEX_PATTERNS;
# every pattern is compiled case-insensitively.
_REGEX_PATTERNS_COMPILED = {
    _label: [re.compile(p, re.IGNORECASE) for p in _pats]
    for _label, _pats in _REGEX_PATTERNS.items()
}
def _classify_regex(text):
    """Classify a clause with the keyword/regex fallback patterns.

    Args:
        text: Clause text. ``None`` or empty text yields no matches.

    Returns:
        One finding dict per label in ``_REGEX_PATTERNS_COMPILED`` that has at
        least one matching pattern, in that mapping's order. ``confidence`` is
        always ``None``: a pattern hit is a match, not a calibrated probability.
        ``risk`` falls back to "MEDIUM" for labels missing from ``RISK_MAP``.
    """
    if not text:
        return []
    # The patterns are already compiled with re.IGNORECASE; lower-casing is
    # kept so matching behaves exactly as before.
    text_lower = text.lower()
    results = []
    # Labels are unique dict keys, so each label can be reported at most once;
    # no separate "seen" bookkeeping is required.
    for label, patterns in _REGEX_PATTERNS_COMPILED.items():
        if any(pat.search(text_lower) for pat in patterns):
            results.append({
                "label": label,
                "confidence": None,
                "risk": RISK_MAP.get(label, "MEDIUM"),
                "description": DESC_MAP.get(label, label),
                "source": "pattern",
            })
    return results
# =======================================================================
# 6. LEGAL NER — ML model with regex fallback
# FIX v4.1: Batch all chunks in a single pipeline call
# =======================================================================
def _extract_entities_ml(text):
    """Run the NER pipeline over overlapping character windows of *text*.

    Windows are 512 characters long and start every 450 characters within the
    first 10,000 characters, so neighbouring windows overlap by 62 characters.
    All windows go through ONE batched pipeline call. Entity offsets are shifted
    back into document coordinates, and only entities scoring above 0.5 are kept.

    Raises whatever the pipeline raises; the caller handles the fallback.
    """
    max_text = min(len(text), 10000)
    offsets = list(range(0, max_text, 450))
    chunks = [text[i:i + 512] for i in offsets]
    if not chunks:
        return []
    entities = []
    for offset, ner_results in zip(offsets, ner_pipeline(chunks, batch_size=8)):
        for ent in ner_results:
            # HF pipelines report scores as numpy floats. Cast to a plain float
            # so json.dumps(..., default=str) keeps the score numeric instead
            # of turning it into a string.
            score = float(ent.get("score", 0))
            if score > 0.5:
                entities.append({
                    "text": ent["word"],
                    "type": _map_ner_label(ent.get("entity_group", ent.get("entity", "MISC"))),
                    "start": int(ent["start"]) + offset,
                    "end": int(ent["end"]) + offset,
                    "score": round(score, 3),
                    "source": "ml",
                })
    return entities


def extract_entities(text):
    """Extract entities with the ML model (matterstack/legal-bert-ner) plus regex.

    ML entities take priority. Regex entities are added only where they do not
    overlap any ML span. When the model is unavailable or fails, the regex
    entities alone are used. The combined list is then reduced to
    non-overlapping entities sorted by start offset: the earliest start wins,
    and on ties the longest span wins.

    Args:
        text: Full document text. ``None`` or empty text yields ``[]``.

    Returns:
        List of entity dicts with "text", "type", "start", "end" and "source".
        ML entities also carry "score".
    """
    if not text:
        return []
    # Regex extraction runs exactly once. It is both the fallback and the gap
    # filler; previously it ran twice whenever the model was missing or failed.
    regex_ents = _extract_entities_regex(text)
    entities = []
    if _HAS_NER_MODEL and ner_pipeline is not None:
        try:
            entities = _extract_entities_ml(text)
        except Exception as e:
            print(f"[ClauseGuard] ML NER error, falling back to regex: {e}")
            entities = []
    # Supplement with regex hits (things NER often misses) that do not collide
    # with an ML span.
    ml_spans = set()
    for e in entities:
        ml_spans.update(range(e["start"], e["end"]))
    entities.extend(
        re_ent for re_ent in regex_ents
        if ml_spans.isdisjoint(range(re_ent["start"], re_ent["end"]))
    )
    # Deduplicate and sort: greedily keep spans that start at or after the
    # end of the previously kept span.
    entities.sort(key=lambda x: (x["start"], -(x["end"] - x["start"])))
    filtered = []
    last_end = -1
    for e in entities:
        if e["start"] >= last_end:
            filtered.append(e)
            last_end = e["end"]
    return filtered
# NER tag names (upper-cased, with any BIO/BIOES prefix removed) mapped to
# ClauseGuard entity types. Tags not listed here pass through unchanged.
_NER_LABEL_MAP = {
    "PER": "PERSON", "PERSON": "PERSON",
    "ORG": "PARTY", "ORGANIZATION": "PARTY",
    "LOC": "JURISDICTION", "LOCATION": "JURISDICTION",
    "GPE": "JURISDICTION", "DATE": "DATE",
    "MONEY": "MONEY", "MISC": "MISC", "LAW": "LEGAL_REF",
}

# Token-level tag prefixes that appear in the "entity" key when a
# token-classification pipeline runs without aggregation (e.g. "B-ORG").
_NER_TAG_PREFIXES = ("B-", "I-", "E-", "S-", "L-", "U-")


def _map_ner_label(label):
    """Map a raw NER tag (e.g. "ORG", "b-per", "I-LOC") to a ClauseGuard type.

    Empty or ``None`` tags map to "MISC". Unknown tags are returned upper-cased.
    """
    if not label:
        return "MISC"
    label = str(label).upper()
    if label.startswith(_NER_TAG_PREFIXES):
        label = label[2:]
    return _NER_LABEL_MAP.get(label, label)
# Entity types matched case-insensitively. Every other type (names, roles,
# jurisdictions, defined terms) relies on capitalisation.
_CASE_INSENSITIVE_ENTITY_TYPES = ("DATE", "MONEY", "DURATION", "PERCENTAGE")

_ENTITY_PATTERN_SPECS = [
    (r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b', "DATE"),
    (r'\b\d{1,2}/\d{1,2}/\d{2,4}\b', "DATE"),
    (r'\b\d{1,2}-(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)-\d{2,4}\b', "DATE"),
    (r'\b(?:Effective|Commencement|Expiration|Termination)\s+Date\b', "DATE_REF"),
    # The \b after the magnitude suffix stops "$500 by" (matched
    # case-insensitively) from being captured as "$500 b".
    (r'\$\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?(?:\s*(?:million|billion|thousand|M|B|K)\b)?', "MONEY"),
    (r'\b\d{1,3}(?:,\d{3})*(?:\.\d{2})?\s*(?:USD|EUR|GBP|dollars|euros|pounds)', "MONEY"),
    (r'\b(?:USD|EUR|GBP)\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})?', "MONEY"),
    (r'\b\d+(?:\.\d+)?%', "PERCENTAGE"),
    (r'\b\d+\s*(?:year|month|week|day|business day)s?\b', "DURATION"),
    # A company name is a run of capitalised words (optionally joined by "&"
    # or "of") on one line, followed by a corporate suffix. The previous
    # character class also accepted lowercase words and newlines, so it
    # captured text such as "This Agreement is made between Acme Inc".
    (r"\b[A-Z][\w&.'-]*(?:[ \t]+(?:[A-Z][\w&.'-]*|&|of))*?,?[ \t]+(?:Inc\.?|LLC|Ltd\.?|Limited|Corp\.?|Corporation|PLC|GmbH|AG|S\.A\.?|B\.V\.?|L\.P\.?|LLP)\b", "PARTY"),
    (r'\b(?:Party A|Party B|Disclosing Party|Receiving Party|Licensor|Licensee|Buyer|Seller|Tenant|Landlord|Employer|Employee|Customer|Vendor|Client)\b', "PARTY_ROLE"),
    # Only capitalised words on the same line belong to the jurisdiction name
    # ("State of New York"). The previous [a-zA-Z\s]+ ran on into the
    # following prose ("... Delaware without regard to ...").
    (r'\b(?:State|Commonwealth)\s+of\s+[A-Z][a-zA-Z]*(?:[ \t]+[A-Z][a-zA-Z]*)*', "JURISDICTION"),
    (r'\b(?:California|Delaware|New York|Texas|Florida|England|Ireland|Germany|France|Singapore|Hong Kong|Ontario|British Columbia)\b', "JURISDICTION"),
    (r'"([A-Z][A-Za-z\s]{1,40})"', "DEFINED_TERM"),
    (r'\((?:the\s+)?"([A-Z][A-Za-z\s]{1,40})"\)', "DEFINED_TERM"),
]

# Compiled once at import time, in the same order as the specs above.
_ENTITY_PATTERNS = [
    (re.compile(pat, re.IGNORECASE if etype in _CASE_INSENSITIVE_ENTITY_TYPES else 0), etype)
    for pat, etype in _ENTITY_PATTERN_SPECS
]


def _extract_entities_regex(text):
    """Regex-based NER fallback.

    Args:
        text: Document text. ``None`` or empty text yields ``[]``.

    Returns:
        Entity dicts ("text", "type", "start", "end", "source"="pattern"),
        grouped by pattern in the order of ``_ENTITY_PATTERN_SPECS``. Matches
        may overlap; the caller deduplicates them.
    """
    if not text:
        return []
    entities = []
    for pattern, etype in _ENTITY_PATTERNS:
        for m in pattern.finditer(text):
            entities.append({
                # DEFINED_TERM patterns capture the term without its quotes or
                # parentheses; the span still covers the whole match.
                "text": m.group(1) if m.lastindex else m.group(),
                "type": etype,
                "start": m.start(),
                "end": m.end(),
                "source": "pattern",
            })
    return entities
# =======================================================================
# 7. NLI / CONTRADICTION DETECTION
# FIX v4.2: NLI pairs are scored with sentence_transformers
#   CrossEncoder.predict() on (text_a, text_b) tuples, replacing the v4.1
#   pipeline("text-classification") call with dict input.
# =======================================================================
def _run_nli(text_a, text_b):
    """Classify the relation between two clause texts with the NLI cross-encoder.

    Calls ``nli_model.predict()`` (a sentence-transformers ``CrossEncoder``) on
    one ``(text_a, text_b)`` pair, each text truncated to 256 characters.

    For multi-class heads, ``CrossEncoder.predict`` returns raw logits unless
    an activation is configured. Callers compare the returned ``score`` with a
    0.6 threshold and display it as a percentage, so the logits are converted
    to probabilities with a softmax here. (This assumes ``nli_model`` uses the
    default identity activation; applying a softmax to outputs that are already
    probabilities keeps the argmax but flattens the score.)

    Returns:
        ``[{"label": "contradiction" | "entailment" | "neutral", "score": p}]``
        with ``p`` in [0, 1], or ``None`` if inference fails (the error is printed).
    """
    import math

    # Class order of the cross-encoder/nli-* checkpoints.
    label_mapping = ["contradiction", "entailment", "neutral"]
    try:
        scores = nli_model.predict([(text_a[:256], text_b[:256])])
        logits = [float(v) for v in scores[0]]
        # Checkpoints differ in class order (some MNLI models put entailment
        # last). Use the model's own id2label when it names exactly these
        # three classes; otherwise keep the default order above.
        config = getattr(getattr(nli_model, "model", None), "config", None)
        id2label = getattr(config, "id2label", None)
        if isinstance(id2label, dict):
            names = [str(id2label.get(i, "")).lower() for i in range(len(logits))]
            if sorted(names) == sorted(label_mapping):
                label_mapping = names
        # Numerically stable softmax.
        peak = max(logits)
        exps = [math.exp(v - peak) for v in logits]
        total = sum(exps)
        top_idx = max(range(len(exps)), key=exps.__getitem__)
        return [{"label": label_mapping[top_idx], "score": exps[top_idx] / total}]
    except Exception as e:
        print(f"[ClauseGuard] NLI inference error: {e}")
        return None
# Critical clauses whose absence from the document is reported as MISSING.
# Compiled once at import time instead of on every detect_contradictions() call.
_REQUIRED_CLAUSE_PATTERNS = {
    "Governing Law": re.compile(
        r'govern(?:ed|ing).{0,15}law|applicable.{0,10}law|laws?\s+of\s+the\s+state',
        re.IGNORECASE,
    ),
    "Limitation of liability": re.compile(
        r'limitation.{0,10}liabilit|cap.{0,10}liabilit|liabilit.{0,10}shall\s+not\s+exceed|in\s+no\s+event.{0,20}liable',
        re.IGNORECASE,
    ),
    "Arbitration": re.compile(
        r'arbitrat|AAA|JAMS|binding.{0,10}dispute',
        re.IGNORECASE,
    ),
    "Termination": re.compile(
        r'terminat(?:e|ion|ed)|cancel(?:lation)?',
        re.IGNORECASE,
    ),
}

# (label_a, label_b, explanation): clauses of the two labels are compared
# with the NLI model, and a confident contradiction is reported.
_NLI_CONFLICT_PAIRS = (
    ("Uncapped Liability", "Cap on Liability",
     "Liability cannot be both uncapped and capped simultaneously."),
    ("IP Ownership Assignment", "Joint IP Ownership",
     "IP cannot be both fully assigned and jointly owned."),
    ("Exclusivity", "Non-Transferable License",
     "Exclusivity and non-transferable license may conflict."),
)

# Labels whose own clauses are compared with each other via NLI.
_NLI_SELF_CHECK_LABELS = ("Governing Law", "Termination for Convenience")

# Label groups flagged on co-occurrence alone when no NLI model is loaded.
_HEURISTIC_CONFLICT_PAIRS = (
    (["Uncapped Liability"], ["Cap on Liability"],
     "Liability cannot be both uncapped and capped simultaneously."),
    (["IP Ownership Assignment"], ["Joint IP Ownership"],
     "IP cannot be both fully assigned and jointly owned."),
)


def _nli_contradiction_scores(text_a, text_b):
    """Yield each NLI score above 0.6 that labels *text_a* vs *text_b* a contradiction."""
    nli_result = _run_nli(text_a, text_b)
    if nli_result is None:
        return
    for r in (nli_result if isinstance(nli_result, list) else [nli_result]):
        if r.get("label", "").lower() == "contradiction" and r.get("score", 0) > 0.6:
            yield r["score"]


def _nli_contradictions(clause_texts_by_label, labels_found):
    """Semantic contradictions between and within labelled clauses, found with the NLI model."""
    findings = []
    for label_a, label_b, explanation in _NLI_CONFLICT_PAIRS:
        if label_a not in labels_found or label_b not in labels_found:
            continue
        # At most 2 x 2 NLI calls per label pair.
        for ta in clause_texts_by_label[label_a][:2]:
            for tb in clause_texts_by_label[label_b][:2]:
                for score in _nli_contradiction_scores(ta, tb):
                    findings.append({
                        "type": "CONTRADICTION",
                        "explanation": explanation,
                        "severity": "HIGH",
                        "clauses": [label_a, label_b],
                        "confidence": round(score, 3),
                        "source": "nli_model",
                    })
    for label in _NLI_SELF_CHECK_LABELS:
        texts = clause_texts_by_label.get(label, [])
        # Compare each clause only with its next two siblings.
        for i in range(len(texts)):
            for j in range(i + 1, min(len(texts), i + 3)):
                for score in _nli_contradiction_scores(texts[i], texts[j]):
                    findings.append({
                        "type": "CONTRADICTION",
                        "explanation": f"Conflicting {label} provisions detected — clauses contradict each other.",
                        "severity": "HIGH",
                        "clauses": [label],
                        "confidence": round(score, 3),
                        "source": "nli_model",
                    })
    return findings


def _heuristic_contradictions(labels_found):
    """Co-occurrence-based conflicts, used when no NLI model is available."""
    return [
        {
            "type": "CONTRADICTION",
            "explanation": explanation,
            "severity": "HIGH",
            "clauses": group_a + group_b,
            "source": "heuristic",
        }
        for group_a, group_b, explanation in _HEURISTIC_CONFLICT_PAIRS
        if any(l in labels_found for l in group_a) and any(l in labels_found for l in group_b)
    ]


def _missing_clause_findings(source_text):
    """One MEDIUM "MISSING" finding per required clause absent from *source_text*."""
    return [
        {
            "type": "MISSING",
            "explanation": f"No '{clause_name}' clause detected in the document.",
            "severity": "MEDIUM",
            "clauses": [clause_name],
            "source": "structural",
        }
        for clause_name, pattern in _REQUIRED_CLAUSE_PATTERNS.items()
        if not pattern.search(source_text)
    ]


def _dedupe_findings(findings):
    """Collapse findings with the same (type, explanation).

    The first occurrence keeps its position, but a later duplicate with a
    higher confidence replaces it, so the strongest evidence is shown.
    """
    unique = {}
    for c in findings:
        key = (c["type"], c["explanation"])
        kept = unique.get(key)
        if kept is None or (c.get("confidence") or 0) > (kept.get("confidence") or 0):
            unique[key] = c
    return list(unique.values())


def detect_contradictions(clause_results, raw_text=""):
    """Detect contradictions and missing critical clauses.

    Uses:
      1. The NLI cross-encoder (semantic contradictions) when loaded;
         otherwise co-occurrence heuristics on mutually exclusive labels.
      2. Missing-critical-clause detection by regex over the document text.

    Args:
        clause_results: Finding dicts with "label" and (optionally) "text".
        raw_text: Full document text for the missing-clause check. When
            empty, the clause texts are searched instead, so an omitted
            argument does not report every critical clause as missing.

    Returns:
        Deduplicated finding dicts ("type", "explanation", "severity",
        "clauses", "source", plus "confidence" for NLI findings).
    """
    labels_found = set()
    clause_texts_by_label = defaultdict(list)
    for cr in clause_results:
        labels_found.add(cr["label"])
        clause_texts_by_label[cr["label"]].append(cr.get("text", ""))

    if _HAS_NLI_MODEL and nli_model is not None:
        findings = _nli_contradictions(clause_texts_by_label, labels_found)
    else:
        findings = _heuristic_contradictions(labels_found)

    source_text = raw_text or " ".join(
        t for texts in clause_texts_by_label.values() for t in texts if t
    )
    findings.extend(_missing_clause_findings(source_text))
    return _dedupe_findings(findings)
# =======================================================================
# 8. RISK SCORING
# FIX v4.1: Absolute risk based on findings, not normalized by document length
# =======================================================================
def compute_risk_score(clause_results, total_clauses):
    """Compute an absolute 0-100 risk score and a letter grade from findings.

    Args:
        clause_results: Finding dicts, each optionally carrying a "risk"
            severity ("CRITICAL", "HIGH", "MEDIUM" or "LOW", case-insensitive).
            Missing or unrecognised severities are counted as LOW.
        total_clauses: Number of clauses in the document. When it is 0 the
            result is score 0 / grade "A", but the counts are still returned.

    Returns:
        ``(score, grade, sev_counts)``, where ``sev_counts`` maps each of the
        four severities to its number of findings.
    """
    sev_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0}
    for cr in clause_results:
        # Normalise case and fold unknown values into LOW. A missing "risk"
        # already defaulted to LOW; an unexpected value used to raise KeyError.
        sev = str(cr.get("risk") or "LOW").upper()
        sev_counts[sev if sev in sev_counts else "LOW"] += 1
    if total_clauses == 0:
        return 0, "A", sev_counts
    # FIX v4.1: Absolute risk. Critical findings score high regardless of
    # document size: 5 critical findings in a 200-clause document are as
    # dangerous as 5 in a 10-clause document.
    weighted = sum(sev_counts[s] * RISK_WEIGHTS[s] for s in sev_counts)
    # Diminishing returns: 100 * w / (w + 30) is nearly linear for small w
    # and approaches 100 asymptotically.
    risk = min(100, round(100 * (1 - (1 / (1 + weighted / 30)))))
    if risk >= 70:
        grade = "F"
    elif risk >= 50:
        grade = "D"
    elif risk >= 30:
        grade = "C"
    elif risk >= 15:
        grade = "B"
    else:
        grade = "A"
    return risk, grade, sev_counts
# =======================================================================
# 9. MAIN ANALYSIS PIPELINE
# =======================================================================
def analyze_contract(text):
    """Run the end-to-end contract analysis pipeline on *text*.

    Returns:
        ``(result, None)`` on success, or ``(None, error_message)`` when the
        stripped text is shorter than 50 characters or no clauses are found.
    """
    too_short = not text or len(text.strip()) < 50
    if too_short:
        return None, "Document too short (minimum 50 characters)"

    clauses = split_clauses(text)
    if not clauses:
        return None, "No clauses detected in document"

    # PERF v4.3: one batched classification pass over every clause.
    predictions_per_clause = classify_cuad_batch(clauses, batch_size=8)
    # Flatten: one finding row per (clause, predicted label).
    clause_results = [
        {
            "text": clause,
            "label": pred["label"],
            "confidence": pred["confidence"],
            "risk": pred["risk"],
            "description": pred["description"],
            "source": pred.get("source", "unknown"),
        }
        for clause, preds in zip(clauses, predictions_per_clause)
        for pred in (preds or ())
    ]

    entities = extract_entities(text)
    contradictions = detect_contradictions(clause_results, text)
    score, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
    obligations = extract_obligations(text)
    compliance = check_compliance(text)

    # Whitespace-normalised text, so re-flowed copies hash identically.
    normalized = re.sub(r'\s+', ' ', text.strip())
    metadata = {
        "analysis_date": datetime.now().isoformat(),
        "total_clauses": len(clauses),
        "flagged_clauses": len(clause_results),
        "unique_flagged": len({cr["text"] for cr in clause_results}),
        "model": get_model_status_text(),
        "text_hash": hashlib.sha256(normalized.encode()).hexdigest()[:16],
    }
    report = {
        "metadata": metadata,
        "risk": {"score": score, "grade": grade, "breakdown": sev_counts},
        "clauses": clause_results,
        "entities": entities,
        "contradictions": contradictions,
        "obligations": obligations,
        "compliance": compliance,
        "raw_text": text,
    }
    return report, None
# =======================================================================
# 10. EXPORT FUNCTIONS
# =======================================================================
def export_json(result):
    """Serialise an analysis result as 2-space-indented JSON.

    ``None`` passes through unchanged. Values json cannot encode natively
    (datetimes, sets, ...) are converted with ``str``.
    """
    return None if result is None else json.dumps(result, indent=2, default=str)
def export_csv(result):
    """Render the clause findings of *result* as CSV text.

    Columns: clause text (first 500 characters), label, risk, confidence
    (3 decimals, or "pattern match" for regex hits without a score),
    description and source. ``None`` passes through unchanged.
    """
    if result is None:
        return None

    def _as_row(finding):
        confidence = finding.get("confidence")
        confidence_cell = "pattern match" if confidence is None else f"{confidence:.3f}"
        return [
            finding.get("text", "")[:500],
            finding.get("label", ""),
            finding.get("risk", ""),
            confidence_cell,
            finding.get("description", ""),
            finding.get("source", ""),
        ]

    sink = io.StringIO()
    csv_writer = csv.writer(sink)
    csv_writer.writerow(["Clause Text", "Label", "Risk", "Confidence", "Description", "Source"])
    csv_writer.writerows(_as_row(finding) for finding in result.get("clauses", []))
    return sink.getvalue()
# =======================================================================
# 11. UI RENDERING
# =======================================================================
def render_summary(result):
    """Render the risk-score summary card (score, grade, severity counts) as HTML.

    Returns "" for a ``None`` result. Missing severity counts render as 0. The
    model status text is HTML-escaped before interpolation.
    """
    from html import escape as _escape

    if result is None:
        return ""
    risk = result["risk"]
    score = risk["score"]
    grade = risk["grade"]
    breakdown = risk["breakdown"]
    meta = result["metadata"]
    grade_color = {
        "A": "#16a34a", "B": "#65a30d", "C": "#ca8a04",
        "D": "#ea580c", "F": "#dc2626",
    }.get(grade, "#6b7280")
    crit, high, med, low = (breakdown.get(k, 0) for k in ("CRITICAL", "HIGH", "MEDIUM", "LOW"))
    model_text = _escape(str(meta["model"]))
    # The footer separator is a middle dot (U+00B7); it was previously
    # mis-encoded as "Β·".
    html = f"""
<div style="font-family:system-ui,sans-serif;padding:16px;border:1px solid #e5e7eb;border-radius:12px;background:#fff;">
<div style="text-align:center;margin-bottom:16px;">
<div style="font-size:48px;font-weight:700;color:{grade_color};">{score}</div>
<div style="font-size:14px;color:#6b7280;">/100 Risk Score</div>
<div style="display:inline-block;margin-top:8px;padding:4px 16px;border-radius:20px;background:{grade_color};color:white;font-weight:600;font-size:14px;">
Grade {grade}
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:8px;margin-bottom:12px;">
<div style="padding:8px;border-radius:6px;background:#fef2f2;text-align:center;">
<div style="font-size:20px;font-weight:700;color:#dc2626;">{crit}</div>
<div style="font-size:11px;color:#991b1b;">Critical</div>
</div>
<div style="padding:8px;border-radius:6px;background:#fff7ed;text-align:center;">
<div style="font-size:20px;font-weight:700;color:#ea580c;">{high}</div>
<div style="font-size:11px;color:#9a3412;">High</div>
</div>
<div style="padding:8px;border-radius:6px;background:#fefce8;text-align:center;">
<div style="font-size:20px;font-weight:700;color:#ca8a04;">{med}</div>
<div style="font-size:11px;color:#854d0e;">Medium</div>
</div>
<div style="padding:8px;border-radius:6px;background:#f0fdf4;text-align:center;">
<div style="font-size:20px;font-weight:700;color:#16a34a;">{low}</div>
<div style="font-size:11px;color:#166534;">Low</div>
</div>
</div>
<div style="font-size:12px;color:#6b7280;text-align:center;">
{meta['total_clauses']} clauses analyzed · {meta['flagged_clauses']} flagged
<br><span style="font-size:10px;">{model_text}</span>
</div>
</div>
"""
    return html
def render_clause_cards(result):
    """Render flagged clauses as HTML cards, one card per distinct clause text.

    Findings that share a clause text are merged into one card. The card's
    colour and icon follow the highest risk among those findings. Each finding
    becomes a tag (source icon, label, and confidence or "pattern") plus a
    description line.

    Clause text comes from the uploaded document, so it is escaped with
    ``html.escape`` (which also handles ``&``), as are labels and descriptions.
    """
    from html import escape as _escape

    if result is None:
        return ""
    clauses = result.get("clauses", [])
    if not clauses:
        return '<div style="padding:24px;text-align:center;color:#6b7280;">No clauses detected.</div>'
    risk_rank = {"CRITICAL": 4, "HIGH": 3, "MEDIUM": 2, "LOW": 1}
    grouped = defaultdict(list)
    for cr in clauses:
        grouped[cr["text"]].append(cr)
    parts = ['<div style="font-family:system-ui,sans-serif;">']
    for text, items in grouped.items():
        max_risk = max(items, key=lambda x: risk_rank[x["risk"]])["risk"]
        border, _bg, icon = RISK_STYLES[max_risk]
        tag_parts = []
        for item in items:
            tag_color = RISK_STYLES[item["risk"]][0]
            tag_bg = RISK_STYLES[item["risk"]][1]
            conf = item.get("confidence")
            conf_text = f"{conf:.0%}" if conf is not None else "pattern"
            # NOTE(review): these icon literals look mis-encoded (emoji
            # mojibake); confirm the intended characters.
            source_icon = "π€" if item.get("source", "") == "ml" else "π"
            label = _escape(str(item["label"]))
            tag_parts.append(
                f'<span style="background:{tag_bg};color:{tag_color};border:1px solid {tag_color}33;padding:2px 8px;border-radius:12px;font-size:11px;font-weight:500;margin-right:4px;">'
                f'{source_icon} {label} ({conf_text})</span>'
            )
        tags = "".join(tag_parts)
        descs = "".join(
            f'<p style="font-size:12px;color:#6b7280;margin:4px 0 0 0;">{_escape(str(item["description"]))}</p>'
            for item in items
        )
        # Truncate first, then escape, so no entity is cut in half.
        preview = text[:300] + ("..." if len(text) > 300 else "")
        preview = _escape(preview, quote=False)
        parts.append(f"""
<div style="border:1px solid #e5e7eb;border-left:4px solid {border};border-radius:8px;padding:14px;margin-bottom:10px;background:#fafafa;">
<div style="display:flex;align-items:center;gap:6px;margin-bottom:6px;">
<span style="font-size:16px;">{icon}</span>
<span style="font-size:12px;font-weight:600;color:{border};text-transform:uppercase;">{max_risk}</span>
</div>
<p style="font-size:13px;color:#374151;line-height:1.6;margin:0 0 8px 0;">{preview}</p>
<div style="margin-bottom:6px;">{tags}</div>
{descs}
</div>
""")
    parts.append("</div>")
    return "".join(parts)
def render_entities(result):
    """Render extracted entities as coloured chips grouped by entity type.

    Within each type, duplicate texts are dropped (first occurrence kept) and
    at most 20 chips are shown. Entity text comes from the document, so it is
    HTML-escaped before interpolation.
    """
    from html import escape as _escape

    if result is None:
        return ""
    entities = result.get("entities", [])
    if not entities:
        return '<div style="padding:16px;color:#6b7280;">No entities detected.</div>'
    palette = {
        "DATE": "#3b82f6", "DATE_REF": "#60a5fa",
        "MONEY": "#22c55e", "PERCENTAGE": "#10b981",
        "DURATION": "#6366f1",
        "PARTY": "#8b5cf6", "PARTY_ROLE": "#a78bfa",
        "PERSON": "#ec4899",
        "JURISDICTION": "#f59e0b",
        "DEFINED_TERM": "#ec4899",
        "LEGAL_REF": "#6b7280",
        "MISC": "#9ca3af",
    }
    grouped = defaultdict(list)
    for e in entities:
        grouped[e["type"]].append(e["text"])
    parts = ['<div style="font-family:system-ui,sans-serif;">']
    for etype, texts in grouped.items():
        unique = list(dict.fromkeys(texts))[:20]
        color = palette.get(etype, "#6b7280")
        items_html = "".join(
            f'<span style="display:inline-block;background:{color}15;color:{color};border:1px solid {color}40;padding:3px 10px;border-radius:6px;font-size:12px;margin:3px;">{_escape(str(t))}</span>'
            for t in unique
        )
        type_label = _escape(str(etype))
        parts.append(f"""
<div style="margin-bottom:12px;">
<div style="font-size:12px;font-weight:600;color:#374151;margin-bottom:6px;text-transform:uppercase;">{type_label}</div>
<div>{items_html}</div>
</div>
""")
    parts.append("</div>")
    return "".join(parts)
def render_contradictions(result):
    """Render contradiction / missing-clause findings as HTML cards.

    Returns "" for a ``None`` result and a single "all clear" line when the
    result has no findings. Each card's accent colour is the first element of
    ``RISK_STYLES[finding["severity"]]``. NLI findings get a badge with their
    confidence as a percentage; heuristic findings get a "Heuristic" badge.
    """
    # NOTE(review): the icon/checkmark literals in this function look
    # mis-encoded (emoji mojibake); confirm the intended characters.
    if result is None:
        return ""
    contradictions = result.get("contradictions", [])
    if not contradictions:
        return '<div style="padding:16px;color:#16a34a;">β No contradictions or missing clauses detected.</div>'
    html = '<div style="font-family:system-ui,sans-serif;">'
    for c in contradictions:
        sev_color = RISK_STYLES[c["severity"]][0]
        # One icon for CONTRADICTION findings, another for everything else
        # (e.g. MISSING).
        icon = "β οΈ" if c["type"] == "CONTRADICTION" else "π"
        source = c.get("source", "")
        source_badge = ""
        if source == "nli_model":
            conf = c.get("confidence", 0)
            source_badge = f'<span style="font-size:10px;background:#eff6ff;color:#3b82f6;padding:1px 6px;border-radius:4px;margin-left:8px;">π€ NLI {conf:.0%}</span>'
        elif source == "heuristic":
            source_badge = '<span style="font-size:10px;background:#fef3c7;color:#92400e;padding:1px 6px;border-radius:4px;margin-left:8px;">π Heuristic</span>'
        # c["explanation"] is interpolated unescaped; in SOURCE it is built
        # from fixed messages and label names rather than document text.
        html += f"""
<div style="border:1px solid #e5e7eb;border-left:4px solid {sev_color};border-radius:8px;padding:12px;margin-bottom:8px;background:#fafafa;">
<div style="display:flex;align-items:center;gap:6px;margin-bottom:4px;">
<span>{icon}</span>
<span style="font-size:12px;font-weight:600;color:{sev_color};">{c["type"]}</span>
{source_badge}
</div>
<p style="font-size:13px;color:#374151;margin:0;">{c["explanation"]}</p>
</div>
"""
    html += "</div>"
    return html
def render_document_viewer(result):
    """Render the raw document as preformatted HTML with entity highlights.

    Entities are processed by start offset. An entity that starts inside an
    already-highlighted span is skipped. All document text (plain runs and
    highlighted spans) is escaped with ``html.escape``, so ``&``, ``<`` and
    ``>`` in the contract display literally instead of being parsed as markup.
    """
    from html import escape as _escape

    if result is None:
        return ""
    text = result.get("raw_text", "")
    entities = sorted(result.get("entities", []), key=lambda x: x["start"])
    entity_colors = {
        "DATE": "#3b82f6", "DATE_REF": "#60a5fa", "MONEY": "#22c55e",
        "PERCENTAGE": "#10b981", "DURATION": "#6366f1", "PARTY": "#8b5cf6",
        "PARTY_ROLE": "#a78bfa", "PERSON": "#ec4899", "JURISDICTION": "#f59e0b",
        "DEFINED_TERM": "#ec4899", "LEGAL_REF": "#6b7280", "MISC": "#9ca3af",
    }
    parts = []
    cursor = 0
    for e in entities:
        if e["start"] < cursor:
            continue  # overlaps a span that is already highlighted
        parts.append(_escape(text[cursor:e["start"]], quote=False))
        color = entity_colors.get(e["type"], "#6b7280")
        entity_text = _escape(text[e["start"]:e["end"]], quote=False)
        etype = _escape(str(e["type"]))
        parts.append(
            f'<span style="background:{color}20;color:{color};border-bottom:2px solid {color};padding:0 2px;border-radius:2px;" '
            f'title="{etype}">{entity_text}</span>'
        )
        cursor = e["end"]
    if cursor < len(text):
        parts.append(_escape(text[cursor:], quote=False))
    body = "".join(parts)
    return f'<div style="font-family:ui-monospace,monospace;font-size:13px;line-height:1.8;white-space:pre-wrap;padding:16px;">{body}</div>'