ClauseGuard / compare.py
gaurv007's picture
⚑ v4.3: Performance optimizations — ONNX INT8, BGE embedder, batched classification, thread control (#4)
f4b6528
"""
ClauseGuard — Contract Comparison Engine v3.1
═════════════════════════════════════════════
FIXED in v3.1:
• PERF: Pre-compute all embeddings once, use matrix multiplication (was O(n²) per-pair encoding)
• FIX: Shared SentenceTransformer singleton (no duplicate model loading)
• FIX: Raised similarity thresholds to reduce false matches
"""
import re
from difflib import SequenceMatcher
from collections import defaultdict
import numpy as np
# Try to load sentence-transformers for semantic comparison
_HAS_EMBEDDINGS = False
_embedder = None
try:
from sentence_transformers import SentenceTransformer
_HAS_EMBEDDINGS = True
except ImportError:
pass
def _load_embedder():
    """Initialise the module-wide SentenceTransformer instance on first use.

    Does nothing when sentence-transformers could not be imported or when a
    model is already loaded. A failed load is reported on stdout and leaves
    ``_embedder`` as ``None``, so later calls will try again.

    PERF v4.3 (original author's note): the model was switched to
    BAAI/bge-small-en-v1.5, cited as +21% retrieval accuracy.
    """
    global _embedder
    if not _HAS_EMBEDDINGS or _embedder is not None:
        return
    try:
        _embedder = SentenceTransformer("BAAI/bge-small-en-v1.5")
        print("[ClauseGuard] Sentence embeddings loaded for comparison (BGE-small)")
    except Exception as e:
        # Best-effort: the caller falls back to lexical comparison.
        print(f"[ClauseGuard] Embeddings not available: {e}")
def _normalize_clause(text):
"""Normalize clause text for comparison."""
text = text.lower()
text = re.sub(r'[^a-z0-9\s]', ' ', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
def _compute_similarity_matrix(clauses_a, clauses_b):
    """Build the pairwise similarity matrix between two clause lists.

    When the shared embedder is loaded, each side is encoded in one batched
    call and the matrix is the dot product of the two embedding arrays
    (FIX v3.1: replaces per-pair ``encode()`` calls). If the embedder is
    missing or encoding fails, every pair is scored with
    ``difflib.SequenceMatcher`` on normalised text instead.

    Args:
        clauses_a: Sequence of clause strings from contract A.
        clauses_b: Sequence of clause strings from contract B.

    Returns:
        ``(sim_matrix, method)`` where ``sim_matrix[i, j]`` scores
        ``clauses_a[i]`` against ``clauses_b[j]`` and ``method`` is
        ``"semantic"`` or ``"lexical"``.
    """
    if _embedder is not None:
        try:
            # Only the first 512 characters of each clause are embedded.
            emb_a = _embedder.encode(
                [c[:512] for c in clauses_a],
                normalize_embeddings=True, batch_size=32, show_progress_bar=False,
            )
            emb_b = _embedder.encode(
                [c[:512] for c in clauses_b],
                normalize_embeddings=True, batch_size=32, show_progress_bar=False,
            )
            # Normalised embeddings are requested above, so the dot product
            # is used directly as the cosine similarity.
            return np.dot(emb_a, emb_b.T), "semantic"
        except Exception as e:
            # Best-effort: report the reason instead of hiding it, then fall
            # through to the lexical comparison below.
            print(f"[ClauseGuard] Semantic comparison failed, using lexical fallback: {e}")

    # Lexical fallback. Normalise every clause exactly once up front; doing
    # it inside the pair loop repeated the same work for each A clause.
    norm_a = [_normalize_clause(c) for c in clauses_a]
    norm_b = [_normalize_clause(c) for c in clauses_b]
    sim_matrix = np.zeros((len(norm_a), len(norm_b)))
    for i, a in enumerate(norm_a):
        for j, b in enumerate(norm_b):
            sim_matrix[i, j] = SequenceMatcher(None, a, b).ratio()
    return sim_matrix, "lexical"
def _extract_clause_type(clause_text):
"""Clause type detection with legal taxonomy."""
text_lower = clause_text.lower()
type_keywords = {
"governing law": ["govern", "law of", "jurisdiction of", "applicable law"],
"termination": ["terminat", "cancel", "expir"],
"indemnification": ["indemnif", "hold harmless", "defend and indemnify"],
"confidentiality": ["confidential", "non-disclosure", "nda", "proprietary"],
"liability": ["liability", "liable", "damages", "limitation of"],
"payment": ["payment", "fee", "price", "compensat", "invoice", "remit"],
"intellectual property": ["intellectual property", "ip rights", "copyright", "patent", "trademark"],
"warranty": ["warrant", "guarantee", "representation"],
"force majeure": ["force majeure", "act of god", "beyond control"],
"arbitration": ["arbitrat", "mediation", "dispute resolution"],
"assignment": ["assign", "transfer of rights"],
"non-compete": ["non-compete", "not compete", "competition"],
"renewal": ["renew", "extend", "automatic renewal"],
"effective date": ["effective date", "commencement"],
"insurance": ["insurance", "coverage", "policy of insurance"],
"audit": ["audit", "inspection", "examination of records"],
"data protection": ["data protection", "privacy", "personal data", "gdpr", "ccpa"],
"notice": ["notice", "notification", "written notice"],
}
for ctype, keywords in type_keywords.items():
if any(kw in text_lower for kw in keywords):
return ctype
return "general"
def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
    """Compare two contracts clause by clause and summarise the differences.

    Clauses of contract A are visited in order; each greedily claims the
    still-unclaimed clause of contract B with the highest similarity score.
    A best score >= 0.75 is a match (recorded as "modified" unless it is
    >= 0.95), a score >= 0.55 is recorded as "partial", and anything lower
    leaves the A clause unmatched ("removed"). B clauses that are never
    claimed are reported as "added".

    Args:
        text_a: Full text of contract A.
        text_b: Full text of contract B.
        clauses_a: Optional pre-split clauses of A; derived from ``text_a``
            when ``None``.
        clauses_b: Optional pre-split clauses of B; derived from ``text_b``
            when ``None``.

    Returns:
        ``{"error": ...}`` when either text is empty. Otherwise a dict with
        the keys ``alignment_score``, ``contract_a_clauses``,
        ``contract_b_clauses``, ``contract_a_type``, ``contract_b_type``,
        ``is_cross_domain``, ``added_clauses``, ``removed_clauses``,
        ``modified_clauses`` (each list capped at 50 entries, clause text
        capped at 200 characters), ``risk_delta``, ``risk_winner``,
        ``comparison_method``, ``type_map_a`` and ``type_map_b``.
    """
    if not text_a or not text_b:
        return {"error": "Both contracts required"}

    _load_embedder()
    if clauses_a is None:
        clauses_a = _split_clauses(text_a)
    if clauses_b is None:
        clauses_b = _split_clauses(text_b)

    # Lower-case each contract once; both the contract-type detection and
    # the risk keyword count below scan the lower-cased text.
    lower_a = text_a.lower()
    lower_b = text_b.lower()

    # Detect contract types and flag cross-domain comparisons
    contract_type_keywords = {
        "employment": ["employee", "employer", "salary", "compensation", "benefits", "vacation", "severance", "at-will"],
        "lease": ["landlord", "tenant", "rent", "premises", "lease", "occupancy", "security deposit", "eviction"],
        "service": ["service provider", "customer", "SLA", "deliverables", "statement of work", "SOW"],
        "nda": ["confidential", "non-disclosure", "disclosing party", "receiving party"],
        "saas": ["subscription", "SaaS", "cloud", "uptime", "API", "data processing"],
        "purchase": ["buyer", "seller", "purchase order", "goods", "shipment", "delivery"],
    }

    def _detect_contract_type(text_lower):
        # Score = number of distinct keywords present; at least two hits are
        # required, and ties go to the type declared first.
        scores = {
            ctype: sum(1 for kw in keywords if kw.lower() in text_lower)
            for ctype, keywords in contract_type_keywords.items()
        }
        best = max(scores, key=scores.get)
        return best if scores[best] >= 2 else "general"

    type_a = _detect_contract_type(lower_a)
    type_b = _detect_contract_type(lower_b)
    is_cross_domain = type_a != type_b and type_a != "general" and type_b != "general"

    # Classify every clause once; the labels feed the per-type counts and
    # the added/removed/modified entries.
    types_a = [_extract_clause_type(c) for c in clauses_a]
    types_b = [_extract_clause_type(c) for c in clauses_b]
    type_counts_a = defaultdict(int)
    type_counts_b = defaultdict(int)
    for ctype in types_a:
        type_counts_a[ctype] += 1
    for ctype in types_b:
        type_counts_b[ctype] += 1

    # FIX v3.1: the similarity matrix is computed once for all pairs.
    if clauses_a and clauses_b:
        sim_matrix, method_type = _compute_similarity_matrix(clauses_a, clauses_b)
    else:
        sim_matrix = np.zeros((0, 0))
        method_type = "none"

    SIMILARITY_THRESHOLD = 0.75
    MODIFIED_THRESHOLD = 0.55
    NEAR_IDENTICAL_THRESHOLD = 0.95  # matches at or above this are not listed

    n_b = len(clauses_b)
    matched_a = set()
    # available[j] is True while clause j of B has not been claimed. The mask
    # is updated in place rather than rebuilt for every A clause.
    available = np.ones(n_b, dtype=bool)
    modified = []

    if n_b:
        for i, clause in enumerate(clauses_a):
            if not available.any():
                break
            # Claimed B clauses are forced to -1.0 so argmax skips them.
            masked_row = np.where(available, sim_matrix[i], -1.0)
            best_j = int(np.argmax(masked_row))
            best_sim = masked_row[best_j]

            if best_sim >= SIMILARITY_THRESHOLD:
                label = "modified" if best_sim < NEAR_IDENTICAL_THRESHOLD else None
            elif best_sim >= MODIFIED_THRESHOLD:
                label = "partial"
            else:
                continue

            matched_a.add(i)
            available[best_j] = False
            if label is not None:
                modified.append({
                    "type": label,
                    "similarity": round(float(best_sim), 3),
                    "clause_a": clause[:200],
                    "clause_b": clauses_b[best_j][:200],
                    "clause_type": types_a[i],
                })

    removed = [
        (clause, ctype)
        for i, (clause, ctype) in enumerate(zip(clauses_a, types_a))
        if i not in matched_a
    ]
    added = [
        (clause, ctype)
        for clause, ctype, free in zip(clauses_b, types_b, available)
        if free
    ]

    # Alignment: share of the longer contract's clauses that found a partner.
    total_pairs = max(len(clauses_a), n_b)
    alignment = len(matched_a) / total_pairs if total_pairs > 0 else 0.0

    # Risk delta: count how many distinct risk keywords occur in each text.
    risk_keywords = ["unlimited", "unilateral", "waive", "arbitration", "indemnif",
                     "not liable", "no warranty", "sole discretion", "terminate",
                     "non-compete", "liquidated damages", "uncapped"]
    risk_a = sum(1 for kw in risk_keywords if kw in lower_a)
    risk_b = sum(1 for kw in risk_keywords if kw in lower_b)

    if risk_a > risk_b + 2:
        risk_delta = "Contract A is significantly riskier"
        risk_winner = "B"
    elif risk_b > risk_a + 2:
        risk_delta = "Contract B is significantly riskier"
        risk_winner = "A"
    elif risk_a > risk_b:
        risk_delta = "Contract A is slightly riskier"
        risk_winner = "B"
    elif risk_b > risk_a:
        risk_delta = "Contract B is slightly riskier"
        risk_winner = "A"
    else:
        risk_delta = "Similar risk profiles"
        risk_winner = "tie"

    # A cross-domain pairing overrides whatever the keyword count concluded.
    if is_cross_domain:
        risk_delta = f"Cross-domain comparison ({type_a} vs {type_b}) — risk delta not meaningful across different contract types"
        risk_winner = "cross-domain"

    # NOTE(review): method_type "none" (one side has no clauses) is also
    # reported as lexical here; kept as-is for backward compatibility.
    if method_type == "semantic":
        comparison_method = "semantic (sentence embeddings)"
    else:
        comparison_method = "lexical (string matching)"

    return {
        "alignment_score": round(alignment, 3),
        "contract_a_clauses": len(clauses_a),
        "contract_b_clauses": n_b,
        "contract_a_type": type_a,
        "contract_b_type": type_b,
        "is_cross_domain": is_cross_domain,
        "added_clauses": [{"text": c[:200], "type": t} for c, t in added[:50]],
        "removed_clauses": [{"text": c[:200], "type": t} for c, t in removed[:50]],
        "modified_clauses": modified[:50],
        "risk_delta": risk_delta,
        "risk_winner": risk_winner,
        "comparison_method": comparison_method,
        "type_map_a": dict(type_counts_a),
        "type_map_b": dict(type_counts_b),
    }
def _split_clauses(text):
"""Split text into clauses."""
text = re.sub(r'\n{3,}', '\n\n', text.strip())
section_splits = re.split(
r'(?:\n\n)(?=\d+[.)]\s|\([a-z]\)\s|(?:Section|Article|Clause)\s+\d+)',
text
)
if len(section_splits) >= 3:
return [p.strip() for p in section_splits if len(p.strip()) > 30]
parts = re.split(
r'(?<=[.!?])\s+(?=[A-Z0-9(])|(?:\n\n)',
text
)
return [p.strip() for p in parts if len(p.strip()) > 30]
def render_comparison_html(result):
    """Render a comparison result dict as an HTML fragment for Gradio.

    All text taken from ``result`` (clause excerpts, clause types, the risk
    summary, the method label and error messages) is HTML-escaped, because
    clause text comes straight from the compared documents. Clause excerpts
    are cut to 150 characters and get a trailing "..." only when something
    was actually cut.

    Args:
        result: Either ``{"error": message}`` or a dict providing
            ``contract_a_clauses``, ``contract_b_clauses``,
            ``alignment_score``, ``risk_winner``, ``risk_delta``,
            ``modified_clauses``, ``added_clauses`` and ``removed_clauses``;
            ``comparison_method`` is optional.

    Returns:
        The HTML markup as a string. At most 20 modified, 15 added and
        15 removed clauses are shown.
    """
    # Imported locally under its own name so it cannot clash with any local
    # variable called ``html``.
    from html import escape

    def _snippet(text, limit=150):
        shown = escape(text[:limit])
        return shown + "..." if len(text) > limit else shown

    if "error" in result:
        return f'<p style="color:#dc2626;">{escape(str(result["error"]))}</p>'

    method = escape(str(result.get("comparison_method", "unknown")))

    # Green box for a tie, red for every other outcome (including
    # cross-domain comparisons).
    is_tie = result["risk_winner"] == "tie"
    box_bg = "#f0fdf4" if is_tie else "#fef2f2"
    box_border = "#bbf7d0" if is_tie else "#fecaca"
    box_text = "#16a34a" if is_tie else "#dc2626"

    parts = [f'''
<div style="font-family:system-ui,sans-serif;">
<div style="font-size:10px;color:#6b7280;text-align:center;margin-bottom:12px;">Comparison method: {method}</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;margin-bottom:16px;">
<div style="padding:12px;border-radius:8px;background:#eff6ff;border:1px solid #bfdbfe;text-align:center;">
<div style="font-size:24px;font-weight:700;color:#1d4ed8;">{result["contract_a_clauses"]}</div>
<div style="font-size:12px;color:#3b82f6;">Clauses in Contract A</div>
</div>
<div style="padding:12px;border-radius:8px;background:#fefce8;border:1px solid #fde68a;text-align:center;">
<div style="font-size:24px;font-weight:700;color:#a16207;">{result["contract_b_clauses"]}</div>
<div style="font-size:12px;color:#ca8a04;">Clauses in Contract B</div>
</div>
</div>
<div style="padding:12px;border-radius:8px;background:#f9fafb;border:1px solid #e5e7eb;margin-bottom:16px;text-align:center;">
<div style="font-size:28px;font-weight:700;color:#374151;">{result["alignment_score"]*100:.1f}%</div>
<div style="font-size:12px;color:#6b7280;">Alignment Score</div>
</div>
<div style="padding:12px;border-radius:8px;background:{box_bg};border:1px solid {box_border};margin-bottom:16px;text-align:center;">
<span style="font-size:14px;font-weight:600;color:{box_text};">⚖️ {escape(str(result["risk_delta"]))}</span>
</div>
''']

    if result["modified_clauses"]:
        parts.append('<div style="margin-bottom:16px;"><h3 style="font-size:14px;color:#374151;margin-bottom:8px;">📝 Modified Clauses</h3>')
        for m in result["modified_clauses"][:20]:
            parts.append(f'''
<div style="border:1px solid #e5e7eb;border-radius:6px;padding:10px;margin-bottom:8px;">
<div style="font-size:11px;color:#6b7280;margin-bottom:4px;">{escape(m["clause_type"].upper())} · Similarity: {m["similarity"]*100:.0f}%</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:8px;">
<div style="background:#fef2f2;padding:6px;border-radius:4px;font-size:12px;color:#991b1b;">{_snippet(m["clause_a"])}</div>
<div style="background:#f0fdf4;padding:6px;border-radius:4px;font-size:12px;color:#166534;">{_snippet(m["clause_b"])}</div>
</div>
</div>
''')
        parts.append('</div>')

    if result["added_clauses"]:
        parts.append('<div style="margin-bottom:16px;"><h3 style="font-size:14px;color:#374151;margin-bottom:8px;">➕ Added in Contract B</h3>')
        for a in result["added_clauses"][:15]:
            parts.append(f'<div style="background:#f0fdf4;padding:8px;border-radius:4px;font-size:12px;color:#166534;margin-bottom:4px;border-left:3px solid #22c55e;"><b>{escape(a["type"].upper())}</b> · {_snippet(a["text"])}</div>')
        parts.append('</div>')

    if result["removed_clauses"]:
        parts.append('<div style="margin-bottom:16px;"><h3 style="font-size:14px;color:#374151;margin-bottom:8px;">➖ Removed from Contract A</h3>')
        for r in result["removed_clauses"][:15]:
            parts.append(f'<div style="background:#fef2f2;padding:8px;border-radius:4px;font-size:12px;color:#991b1b;margin-bottom:4px;border-left:3px solid #ef4444;"><b>{escape(r["type"].upper())}</b> · {_snippet(r["text"])}</div>')
        parts.append('</div>')

    parts.append('</div>')
    return "".join(parts)