Merge branch 'main' of https://huggingface.co/spaces/gaurv007/ClauseGuard
Changed files:
- README.md +73 -14
- api/Dockerfile +8 -2
- api/main.py +181 -85
- api/requirements.txt +7 -4
- app.py +211 -48
- chatbot.py +406 -0
- compare.py +33 -2
- ml/ClauseGuard_DeBERTa_Training.ipynb +1041 -0
- ml/requirements.txt +2 -2
- ml/train_classifier_v4.py +434 -0
- obligations.py +12 -8
- ocr_engine.py +218 -0
- redlining.py +591 -0
- requirements.txt +3 -1
- web/.env.example +7 -0
- web/app/api/analyze/route.ts +12 -2
- web/app/api/chat/route.ts +37 -0
- web/app/api/redline/route.ts +37 -0
- web/app/dashboard-pages/analyze/page.tsx +178 -1
- web/app/page.tsx +25 -20
- web/components/nav.tsx +2 -3
README.md
CHANGED
@@ -10,9 +10,17 @@ app_file: app.py
 pinned: false
 ---

-# 🛡️ ClauseGuard — World's Best Open-Source Legal Contract Analysis

-**ClauseGuard** is the most comprehensive open-source AI-powered legal contract analysis tool. It analyzes contracts using state-of-the-art legal NLP models and provides actionable risk assessments.

 ## ✨ Core Features
@@ -26,9 +34,12 @@ pinned: false
 | **Obligation Tracker** | Categorizes action items: monetary 💰, compliance ⚖️, reporting 📊, delivery 📦, termination 🛑 |
 | **Compliance Checker** | Validates against GDPR, CCPA, SOX, HIPAA, and FINRA requirements |
 | **Contract Comparison** | Side-by-side diff between two contracts with alignment scoring |

 ### Document Support
-- **PDF** parsing via `pdfplumber`
 - **DOCX/DOC** parsing via `python-docx`
 - **TXT / Markdown** direct text input

@@ -36,6 +47,8 @@ pinned: false
 - **3-Panel Professional Layout** — Upload sidebar + Main analysis + Summary dashboard
 - **Document Viewer** — Inline entity highlights (colored annotations)
 - **Clause Cards** — Expandable risk-badged cards with confidence scores
 - **Export Reports** — JSON (structured) and CSV (tabular) downloads
 - **Color-Coded Risk Badges** — Instant visual triage

@@ -44,12 +57,61 @@ pinned: false
 | Component | Technology |
 |-----------|------------|
 | Clause Classification | `Mokshith31/legalbert-contract-clause-classification` — LoRA adapter on `nlpaueb/legal-bert-base-uncased`, fine-tuned on CUAD 41-class taxonomy |
-| NER |
-| NLI |
 | Compliance | Regulatory keyword matching across GDPR, CCPA, SOX, HIPAA, FINRA |
-| Comparison |
 | Obligations | Regex pattern matching across 5 obligation categories |

 ## 📊 Risk Scoring Methodology

 Risk scores combine clause detection with weighted severity:
@@ -65,16 +127,10 @@ Final score normalized to 0-100 with letter grades:
 - D (50-69): High risk
 - F (70+): Critical risk

-## 📚 Datasets & Research
-
-- [CUAD](https://huggingface.co/datasets/theatticusproject/cuad-qa) — 510 contracts, 13K annotations, 41 clause categories
-- [LegalBench](https://huggingface.co/datasets/nguha/legalbench) — 322 legal reasoning tasks
-- [LexGLUE](https://huggingface.co/datasets/coastalcph/lex_glue) — Unfair Terms of Service classification
-- Paper: [CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review](https://arxiv.org/abs/2103.06268) (Hendrycks et al., 2021)
-
 ## 🚀 Usage

 1. **Upload** a contract (PDF, DOCX, or TXT) or paste text directly
 2. Click **Analyze Contract**
 3. View results across tabs:
    - **Document**: Full text with inline entity highlights
@@ -83,7 +139,9 @@ Final score normalized to 0-100 with letter grades:
    - **Contradictions**: Conflicting clauses and missing provisions
    - **Obligations**: Action items categorized by type
    - **Compliance**: Regulatory framework checks
 4. **Export** JSON/CSV reports

 ## 🔀 Compare Contracts

@@ -91,7 +149,6 @@ Switch to the **Compare Contracts** tab to:
 - Upload or paste two contracts side-by-side
 - See clause-level diffs (added, removed, modified)
 - Get an alignment score and risk delta
-- View raw JSON comparison data

 ## ⚠️ Disclaimer

@@ -103,6 +160,8 @@ Switch to the **Compare Contracts** tab to:
 - [Clause Classifier Model](https://huggingface.co/Mokshith31/legalbert-contract-clause-classification)
 - [Legal-BERT Base](https://huggingface.co/nlpaueb/legal-bert-base-uncased)
 - [CUAD Dataset](https://huggingface.co/datasets/theatticusproject/cuad-qa)
 - [CUAD Paper (arXiv:2103.06268)](https://arxiv.org/abs/2103.06268)

 ---
@@ -10,9 +10,17 @@ app_file: app.py
 pinned: false
 ---

+# 🛡️ ClauseGuard v4.0 — World's Best Open-Source Legal Contract Analysis

+**ClauseGuard** is the most comprehensive open-source AI-powered legal contract analysis tool. It analyzes contracts using state-of-the-art legal NLP models and provides actionable risk assessments, a Q&A chatbot, clause redlining, and OCR for scanned PDFs.
+
+## 🆕 What's New in v4.0
+
+| Feature | Description |
+|---------|-------------|
+| **🔍 OCR for Scanned PDFs** | Smart PDF router: auto-detects native vs scanned PDFs. Scanned PDFs are processed via the docTR OCR engine (CPU-friendly, ~150MB models) |
+| **💬 Contract Q&A Chatbot** | RAG-powered chatbot that answers questions about your analyzed contract. Uses sentence-transformers for retrieval + Qwen2.5-7B via the HF Inference API for generation |
+| **✏️ Clause Redlining** | 3-tier system: (1) Template lookup from 18+ legal templates based on FTC/EU standards, (2) Keyword-based matching, (3) LLM refinement for CRITICAL/HIGH risk clauses |

 ## ✨ Core Features
@@ -26,9 +34,12 @@ pinned: false
 | **Obligation Tracker** | Categorizes action items: monetary 💰, compliance ⚖️, reporting 📊, delivery 📦, termination 🛑 |
 | **Compliance Checker** | Validates against GDPR, CCPA, SOX, HIPAA, and FINRA requirements |
 | **Contract Comparison** | Side-by-side diff between two contracts with alignment scoring |
+| **Clause Redlining** | Suggests safer alternatives for risky clauses with legal citations |
+| **Q&A Chatbot** | Ask questions about your contract using RAG (Retrieval-Augmented Generation) |
+| **OCR Support** | Process scanned PDFs with the docTR OCR engine |

 ### Document Support
+- **PDF** parsing via `pdfplumber` (native) + `docTR` OCR (scanned)
 - **DOCX/DOC** parsing via `python-docx`
 - **TXT / Markdown** direct text input

@@ -36,6 +47,8 @@ pinned: false
 - **3-Panel Professional Layout** — Upload sidebar + Main analysis + Summary dashboard
 - **Document Viewer** — Inline entity highlights (colored annotations)
 - **Clause Cards** — Expandable risk-badged cards with confidence scores
+- **Redlining Tab** — Side-by-side original vs suggested safer alternatives
+- **Q&A Chat Tab** — Conversational interface to ask questions about the contract
 - **Export Reports** — JSON (structured) and CSV (tabular) downloads
 - **Color-Coded Risk Badges** — Instant visual triage

@@ -44,12 +57,61 @@ pinned: false
 | Component | Technology |
 |-----------|------------|
 | Clause Classification | `Mokshith31/legalbert-contract-clause-classification` — LoRA adapter on `nlpaueb/legal-bert-base-uncased`, fine-tuned on CUAD 41-class taxonomy |
+| Legal NER | `matterstack/legal-bert-ner` (ML) with regex fallback for 7 entity types |
+| NLI | `cross-encoder/nli-deberta-v3-base` (semantic contradiction detection) |
+| Embeddings | `sentence-transformers/all-MiniLM-L6-v2` (384-dim, RAG retrieval) |
+| LLM | `Qwen/Qwen2.5-7B-Instruct` via HF Inference API (chatbot + redlining) |
+| OCR | `docTR` (fast_base + crnn_vgg16_bn) for scanned PDF text extraction |
 | Compliance | Regulatory keyword matching across GDPR, CCPA, SOX, HIPAA, FINRA |
+| Comparison | Semantic similarity with sentence embeddings + string matching fallback |
 | Obligations | Regex pattern matching across 5 obligation categories |

+## 🔍 OCR Architecture (Smart PDF Router)
+
+```
+PDF uploaded
+      ↓
+[detect_if_scanned] — pdfplumber extracts >50 chars/page?
+      ↓                        ↓
+  Native PDF              Scanned PDF
+      ↓                        ↓
+  pdfplumber              docTR OCR (CPU)
+      ↓                        ↓
+      Contract text → existing analysis pipeline
+```
+
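The router above can be sketched in a few lines. Note this is a minimal sketch of the routing logic, not the repo's `ocr_engine.py` (whose `parse_pdf_smart` body is not shown in this diff); the helper names and the averaging heuristic are assumptions, while the 50-chars-per-page threshold and the pdfplumber/docTR split come from the diagram.

```python
# Sketch of the smart PDF router. `is_scanned` is a hypothetical helper;
# the real detect_if_scanned implementation is not visible in this diff.

def is_scanned(page_char_counts, min_chars_per_page=50):
    """Treat a PDF as scanned when the average extractable text
    per page falls at or below the threshold (>50 chars/page = native)."""
    if not page_char_counts:
        return True
    avg = sum(page_char_counts) / len(page_char_counts)
    return avg <= min_chars_per_page

def parse_pdf_smart_sketch(path):
    """Route a PDF to pdfplumber (native) or docTR OCR (scanned)."""
    import pdfplumber  # lazy imports: OCR dependencies are heavy
    with pdfplumber.open(path) as pdf:
        texts = [(page.extract_text() or "") for page in pdf.pages]
    if not is_scanned([len(t) for t in texts]):
        return "\n\n".join(texts), "native"
    # Scanned path: docTR with the architectures named in the tech stack table
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor
    model = ocr_predictor(det_arch="fast_base", reco_arch="crnn_vgg16_bn",
                          pretrained=True)
    return model(DocumentFile.from_pdf(path)).render(), "ocr"
```

Either branch feeds plain contract text into the same downstream analysis pipeline, so the rest of the app never needs to know which extractor ran.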
+## 💬 Q&A Chatbot Architecture (RAG)
+
+```
+User asks question about their contract
+        ↓
+[1] Embed question with all-MiniLM-L6-v2
+        ↓
+[2] Retrieve top-5 most relevant chunks from contract
+        ↓
+[3] Build prompt:
+    - System: ClauseGuard analysis results (clauses, entities, risk scores)
+    - Context: Retrieved contract chunks (≤2.5K tokens)
+    - User question
+        ↓
+[4] Stream response from Qwen2.5-7B via HF Inference API
+```
+
+**Key design:** Analyzed data (clauses, entities, risk scores) goes in the system prompt — NOT through RAG retrieval. Only the raw contract text goes through RAG. This gives the model both structured analysis AND verbatim evidence.
+
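The retrieval step ([1]–[2] above) reduces to cosine similarity between the question embedding and the pre-computed chunk embeddings. A minimal sketch, assuming 384-dim vectors from `all-MiniLM-L6-v2` are already available as a NumPy matrix (the function name is illustrative, not the repo's `chatbot.py` API):

```python
import numpy as np

def top_k_chunks(question_vec, chunk_vecs, chunks, k=5):
    """Return the k contract chunks most similar to the question.

    question_vec: (d,) embedding of the user question
    chunk_vecs:   (n, d) matrix of chunk embeddings
    chunks:       list of n chunk strings
    """
    q = question_vec / (np.linalg.norm(question_vec) + 1e-9)
    m = chunk_vecs / (np.linalg.norm(chunk_vecs, axis=1, keepdims=True) + 1e-9)
    sims = m @ q                                   # cosine similarity per chunk
    order = np.argsort(-sims, kind="stable")[:k]   # best-first indices
    return [chunks[i] for i in order]
```

The selected chunks are then packed (capped at roughly 2.5K tokens) into the context section of the prompt, below the system message carrying the structured analysis.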
+## ✏️ Clause Redlining Architecture (3-Tier)
+
+| Tier | Method | Speed | Hallucination Risk |
+|------|--------|-------|--------------------|
+| **1. Template Lookup** | 18+ pre-written safe alternatives based on FTC/EU/CFPB standards | Instant | Zero |
+| **2. Keyword Matching** | Match clause text to relevant templates via legal keywords | Instant | Zero |
+| **3. LLM Refinement** | Qwen2.5-7B adapts template to specific clause context | ~3-5s | Low (template-anchored) |
+
+Anti-hallucination guardrails:
+- **Template anchor:** LLM can only refine, not generate from scratch
+- **Legal citation:** Every suggestion includes legal basis and consumer standard
+- **Disclaimer:** Clear "Not legal advice" warning
+
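The tier ordering can be sketched as a simple dispatch. This is an illustrative stand-in for `redlining.py` (whose internals are not shown): the template/keyword entries are invented examples, and `llm_refine` is a hypothetical callable, but the control flow mirrors the table — templates first, keywords second, LLM refinement only for CRITICAL/HIGH clauses and only on top of a template.

```python
# Hypothetical data; the real module ships 18+ templates keyed by clause category.
TEMPLATES = {
    "Uncapped Liability": "Liability is capped at the fees paid in the "
                          "12 months preceding the claim.",
}
KEYWORDS = {
    "terminate": "Either party may terminate with 30 days' written notice.",
}

def suggest_redline(category, clause_text, severity, llm_refine=None):
    suggestion = TEMPLATES.get(category)          # tier 1: template lookup
    if suggestion is None:
        for kw, tmpl in KEYWORDS.items():         # tier 2: keyword matching
            if kw in clause_text.lower():
                suggestion = tmpl
                break
    # tier 3: template-anchored LLM pass, gated on severity
    if suggestion and llm_refine and severity in ("CRITICAL", "HIGH"):
        suggestion = llm_refine(clause_text, suggestion)
    return suggestion  # None → no safe alternative known, nothing is invented
```

Because tier 3 only ever rewrites an existing template, the LLM cannot introduce a clause the template library does not already sanction — that is the "template anchor" guardrail.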
 ## 📊 Risk Scoring Methodology

 Risk scores combine clause detection with weighted severity:
@@ -65,16 +127,10 @@ Final score normalized to 0-100 with letter grades:
 - D (50-69): High risk
 - F (70+): Critical risk

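Only the D and F cut-offs survive the diff context above; a grade-mapping sketch for just those visible bands (the A–C thresholds are collapsed out of this hunk and are deliberately not guessed here):

```python
def risk_grade(score: int):
    """Map a 0-100 risk score to the letter grades visible in this diff."""
    if score >= 70:
        return "F"   # Critical risk (70+)
    if score >= 50:
        return "D"   # High risk (50-69)
    return None      # A-C bands are not shown in this hunk
```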
 ## 🚀 Usage

 1. **Upload** a contract (PDF, DOCX, or TXT) or paste text directly
+   - 💡 Scanned PDFs are automatically processed with OCR
 2. Click **Analyze Contract**
 3. View results across tabs:
    - **Document**: Full text with inline entity highlights
@@ -83,7 +139,9 @@ Final score normalized to 0-100 with letter grades:
    - **Contradictions**: Conflicting clauses and missing provisions
    - **Obligations**: Action items categorized by type
    - **Compliance**: Regulatory framework checks
+   - **Redlining**: ✏️ Safer clause alternatives with legal citations
 4. **Export** JSON/CSV reports
+5. Switch to the **💬 Contract Q&A** tab to ask questions about your contract

 ## 🔀 Compare Contracts

@@ -91,7 +149,6 @@ Switch to the **Compare Contracts** tab to:
 - Upload or paste two contracts side-by-side
 - See clause-level diffs (added, removed, modified)
 - Get an alignment score and risk delta

 ## ⚠️ Disclaimer

@@ -103,6 +160,8 @@ Switch to the **Compare Contracts** tab to:
 - [Clause Classifier Model](https://huggingface.co/Mokshith31/legalbert-contract-clause-classification)
 - [Legal-BERT Base](https://huggingface.co/nlpaueb/legal-bert-base-uncased)
 - [CUAD Dataset](https://huggingface.co/datasets/theatticusproject/cuad-qa)
+- [Qwen2.5-7B (Chatbot LLM)](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
+- [docTR OCR](https://github.com/mindee/doctr)
 - [CUAD Paper (arXiv:2103.06268)](https://arxiv.org/abs/2103.06268)

 ---
api/Dockerfile
CHANGED
@@ -2,10 +2,16 @@ FROM python:3.12-slim

 WORKDIR /app

-
 RUN pip install --no-cache-dir -r requirements.txt

-

 EXPOSE 8000

@@ -2,10 +2,16 @@ FROM python:3.12-slim

 WORKDIR /app

+# Install api dependencies
+COPY api/requirements.txt ./requirements.txt
 RUN pip install --no-cache-dir -r requirements.txt

+# Copy shared modules from root (needed by api/main.py)
+COPY app.py compare.py compliance.py obligations.py ./
+COPY ocr_engine.py chatbot.py redlining.py ./
+
+# Copy api files
+COPY api/ ./

 EXPOSE 8000
api/main.py
CHANGED
@@ -1,19 +1,19 @@
 """
-ClauseGuard — FastAPI Backend
 ══════════════════════════════════
-
-•
-•
-•
-•
-• Fixed CORS (removed wildcard)
-• Added proper error responses
 """

 import os
 import re
 import json
 import time
 from contextlib import asynccontextmanager
 from typing import Optional
 from collections import defaultdict
@@ -21,14 +21,14 @@ from datetime import datetime

 import httpx
 import numpy as np
-from fastapi import FastAPI, HTTPException, Depends, Body, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field

 from auth import get_current_user, require_auth

 # ── Import shared modules ──
-# When deployed, these must be in the same directory or on PYTHONPATH
 import sys
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -36,29 +36,32 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 try:
     from app import (
         split_clauses, classify_cuad, extract_entities,
-        detect_contradictions, compute_risk_score,
         CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
         cuad_model, cuad_tokenizer
     )
     from obligations import extract_obligations
     from compliance import check_compliance
     from compare import compare_contracts
     _SHARED_MODULES = True
-except ImportError:
     _SHARED_MODULES = False
-    print("[API] WARNING: Could not import shared modules

 # ─── Config ───
 SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
 SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
 HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
 SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
-MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "100000"))

 # ─── Rate Limiting ───
-_rate_limits = {}
 RATE_LIMIT_REQUESTS = 30
-RATE_LIMIT_WINDOW = 60

 def _check_rate_limit(client_ip: str) -> bool:
     now = time.time()
@@ -113,25 +116,16 @@ async def supabase_query(table: str, params: dict, headers_extra: dict = {}):
     except Exception:
         return []

 # ─── Request/Response Models ───
 class AnalyzeRequest(BaseModel):
     text: Optional[str] = Field(None, min_length=50)
-    clauses: Optional[list] = None
     source_url: Optional[str] = None

-class AnalyzeResponse(BaseModel):
-    risk_score: int
-    grade: str
-    total_clauses: int
-    flagged_count: int
-    results: list[dict]
-    entities: list[dict]
-    contradictions: list[dict]
-    obligations: list[dict]
-    compliance: dict
-    model: str
-    latency_ms: int
-
 class CompareRequest(BaseModel):
     text_a: str = Field(..., min_length=50)
     text_b: str = Field(..., min_length=50)
@@ -147,21 +141,28 @@ class ExplainResponse(BaseModel):
     legal_basis: str
     recommendation: str

 # ─── App ───
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Models are loaded when app.py is imported
     yield

-app = FastAPI(title="ClauseGuard API", version="

-# FIXED: No wildcard CORS
 ALLOWED_ORIGINS = [
     "https://clauseguardweb.netlify.app",
     "http://localhost:3000",
     "http://localhost:3001",
 ]
-# Allow chrome extensions
 app.add_middleware(
     CORSMiddleware,
     allow_origins=ALLOWED_ORIGINS,
@@ -174,36 +175,36 @@ app.add_middleware(
 @app.get("/health")
 async def health():
     model_status = "ml" if _SHARED_MODULES and cuad_model else "regex"
     return {
         "status": "ok",
         "model": model_status,
-        "version": "
         "shared_modules": _SHARED_MODULES,
     }

-@app.post("/api/analyze"
 async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
-    # Rate limiting
     client_ip = request.client.host if request.client else "unknown"
     if not _check_rate_limit(client_ip):
-        raise HTTPException(status_code=429, detail="Rate limit exceeded.

-    # FIXED: Accept either text or clauses from extension
     text = req.text
     if not text and req.clauses:
         text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)

     if not text or len(text.strip()) < 50:
         raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
-
-    # Max length check
     if len(text) > MAX_TEXT_LENGTH:
-        raise HTTPException(status_code=400, detail=f"Text too long (

     start = time.time()
     clauses = split_clauses(text)
     if not clauses:
-        raise HTTPException(status_code=400, detail="No clauses detected

     clause_results = []
     for clause in clauses:
@@ -224,6 +225,15 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
     risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
     obligations = extract_obligations(text)
     compliance = check_compliance(text)
     latency = int((time.time() - start) * 1000)

     results_for_db = []
@@ -238,6 +248,29 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
         }],
     })

     if user:
         await supabase_insert("analyses", {
             "user_id": user["id"],
@@ -253,46 +286,120 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
             "compliance": compliance,
         })

-    return
-        risk_score
-        grade
-        total_clauses
-        flagged_count
-        results
-        entities
-        contradictions
-        obligations
-        compliance
-
-
-

 @app.post("/api/compare")
 async def compare(req: CompareRequest, request: Request):
     client_ip = request.client.host if request.client else "unknown"
     if not _check_rate_limit(client_ip):
         raise HTTPException(status_code=429, detail="Rate limit exceeded.")
-
-

 @app.post("/api/explain", response_model=ExplainResponse)
 async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
     desc = DESC_MAP.get(req.category, "Unknown category.")
     legal = "Consult local consumer protection laws."
-    recommendation = "Review this clause carefully.

     if SAULLM_ENDPOINT and HF_API_TOKEN:
         try:
             prompt = (
-                f"
-                f"
-                f"
-                f"Category: {req.category}\n\n"
-                f"Provide:\n"
-                f"1. A plain-English explanation of what this clause means\n"
-                f"2. The specific legal basis or consumer protection concern\n"
-                f"3. A practical recommendation\n\n"
-                f"Be concise. 3-4 sentences per section."
             )
             async with httpx.AsyncClient(timeout=30.0) as client:
                 resp = await client.post(
@@ -311,27 +418,16 @@ async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
         except Exception:
             pass

-    return ExplainResponse(
-
-        category=req.category,
-        explanation=desc,
-        legal_basis=legal,
-        recommendation=recommendation,
-    )

 @app.get("/api/history")
 async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
     limit = min(limit, 100)
-    data = await supabase_query(
-        "
-
-
-        "select": "*",
-        "order": "created_at.desc",
-        "limit": str(limit),
-        "offset": str(offset),
-    },
-    )
     return {"analyses": data, "limit": limit, "offset": offset}

 if __name__ == "__main__":
 """
+ClauseGuard — FastAPI Backend v4.0
 ══════════════════════════════════
+New in v4.0:
+• /api/redline — clause redlining suggestions
+• /api/chat — RAG chatbot (streaming)
+• /api/ocr — OCR scanned PDF extraction
+• Updated analysis to include redlining data
 """

 import os
 import re
 import json
 import time
+import uuid
+import tempfile
 from contextlib import asynccontextmanager
 from typing import Optional
 from collections import defaultdict

 import httpx
 import numpy as np
+from fastapi import FastAPI, HTTPException, Depends, Body, Request, UploadFile, File as FastAPIFile
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field

 from auth import get_current_user, require_auth

 # ── Import shared modules ──
 import sys
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 try:
     from app import (
         split_clauses, classify_cuad, extract_entities,
+        detect_contradictions, compute_risk_score, analyze_contract,
         CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
         cuad_model, cuad_tokenizer
     )
     from obligations import extract_obligations
     from compliance import check_compliance
     from compare import compare_contracts
+    from redlining import generate_redlines
+    from chatbot import index_contract, chat_respond
+    from ocr_engine import parse_pdf_smart, get_ocr_status
     _SHARED_MODULES = True
+except ImportError as e:
     _SHARED_MODULES = False
+    print(f"[API] WARNING: Could not import shared modules: {e}")

 # ─── Config ───
 SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
 SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
 HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
 SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
+MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "100000"))

 # ─── Rate Limiting ───
+_rate_limits = {}
 RATE_LIMIT_REQUESTS = 30
+RATE_LIMIT_WINDOW = 60

 def _check_rate_limit(client_ip: str) -> bool:
     now = time.time()
     except Exception:
         return []

+# ─── In-memory RAG session store ───
+_rag_sessions: dict = {}
+_RAG_SESSION_MAX = 100
+
 # ─── Request/Response Models ───
 class AnalyzeRequest(BaseModel):
     text: Optional[str] = Field(None, min_length=50)
+    clauses: Optional[list] = None
     source_url: Optional[str] = None

 class CompareRequest(BaseModel):
     text_a: str = Field(..., min_length=50)
     text_b: str = Field(..., min_length=50)

     legal_basis: str
     recommendation: str

+class ChatRequest(BaseModel):
+    message: str = Field(..., min_length=1, max_length=2000)
+    session_id: str
+    history: Optional[list[dict]] = None
+
+class RedlineRequest(BaseModel):
+    session_id: Optional[str] = None
+    text: Optional[str] = None
+    use_llm: bool = True
+
 # ─── App ───
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     yield

+app = FastAPI(title="ClauseGuard API", version="4.0.0", lifespan=lifespan)

 ALLOWED_ORIGINS = [
     "https://clauseguardweb.netlify.app",
     "http://localhost:3000",
     "http://localhost:3001",
 ]
 app.add_middleware(
     CORSMiddleware,
     allow_origins=ALLOWED_ORIGINS,

 @app.get("/health")
 async def health():
     model_status = "ml" if _SHARED_MODULES and cuad_model else "regex"
+    ocr_status = get_ocr_status() if _SHARED_MODULES else "unavailable"
     return {
         "status": "ok",
         "model": model_status,
+        "version": "4.0.0",
         "shared_modules": _SHARED_MODULES,
+        "ocr": ocr_status,
+        "features": ["analyze", "compare", "redline", "chat", "ocr"],
     }

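Assuming the API is running locally (e.g. uvicorn on port 8000), a client could exercise the endpoints with a stdlib-only sketch like the one below. The endpoint paths and JSON field names come from this diff; everything else (base URL, helper names, timeouts) is illustrative.

```python
import json
from urllib import request

BASE = "http://localhost:8000"  # assumption: local dev server

def build_json_post(url: str, payload: dict) -> request.Request:
    """Build a JSON POST request for the ClauseGuard API."""
    return request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def analyze_remote(text: str) -> dict:
    # Response includes risk_score, grade, redlines, session_id, ...
    req = build_json_post(f"{BASE}/api/analyze", {"text": text})
    with request.urlopen(req, timeout=120) as resp:
        return json.load(resp)

def ask_question(session_id: str, message: str) -> str:
    # session_id comes from a prior /api/analyze call
    req = build_json_post(f"{BASE}/api/chat",
                          {"session_id": session_id, "message": message})
    with request.urlopen(req, timeout=120) as resp:
        return json.load(resp)["response"]
```

The `session_id` returned by `/api/analyze` is what links a later `/api/chat` or `/api/redline` call back to the in-memory RAG session.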
@app.post("/api/analyze")
|
| 189 |
async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
|
|
|
|
| 190 |
client_ip = request.client.host if request.client else "unknown"
|
| 191 |
if not _check_rate_limit(client_ip):
|
| 192 |
+
raise HTTPException(status_code=429, detail="Rate limit exceeded.")
|
| 193 |
|
|
|
|
| 194 |
text = req.text
|
| 195 |
if not text and req.clauses:
|
| 196 |
text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)
|
| 197 |
|
| 198 |
if not text or len(text.strip()) < 50:
|
| 199 |
raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
|
|
|
|
|
|
|
| 200 |
if len(text) > MAX_TEXT_LENGTH:
|
| 201 |
+
raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} chars)")
|
| 202 |
|
| 203 |
start = time.time()
|
| 204 |
+
|
| 205 |
clauses = split_clauses(text)
|
| 206 |
if not clauses:
|
| 207 |
+
raise HTTPException(status_code=400, detail="No clauses detected")
|
| 208 |
|
| 209 |
clause_results = []
|
| 210 |
for clause in clauses:
|
|
|
|
| 225 |
risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
|
| 226 |
obligations = extract_obligations(text)
|
| 227 |
compliance = check_compliance(text)
|
| 228 |
+
|
| 229 |
+
# v4.0: Redlining
|
| 230 |
+
analysis_for_redline = {"clauses": clause_results}
|
| 231 |
+
redlines = []
|
| 232 |
+
try:
|
| 233 |
+
redlines = generate_redlines(analysis_for_redline, use_llm=True)
|
| 234 |
+
except Exception as e:
|
| 235 |
+
print(f"[API] Redlining error: {e}")
|
| 236 |
+
|
| 237 |
latency = int((time.time() - start) * 1000)
|
| 238 |
|
| 239 |
results_for_db = []
|
|
|
|
| 248 |
}],
|
| 249 |
})
|
| 250 |
|
| 251 |
+
# v4.0: RAG indexing
|
| 252 |
+
session_id = None
|
| 253 |
+
try:
|
| 254 |
+
chunks, embeddings, _status = index_contract(text)
|
| 255 |
+
if chunks and embeddings is not None:
|
| 256 |
+
session_id = uuid.uuid4().hex[:12]
|
| 257 |
+
if len(_rag_sessions) >= _RAG_SESSION_MAX:
|
| 258 |
+
oldest = next(iter(_rag_sessions))
|
| 259 |
+
del _rag_sessions[oldest]
|
| 260 |
+
_rag_sessions[session_id] = {
|
| 261 |
+
"chunks": chunks,
|
| 262 |
+
"embeddings": embeddings,
|
| 263 |
+
"analysis": {
|
| 264 |
+
"risk": {"score": risk, "grade": grade, "breakdown": sev_counts},
|
| 265 |
+
"metadata": {"total_clauses": len(clauses), "flagged_clauses": len(clause_results)},
|
| 266 |
+
"clauses": clause_results[:30],
|
| 267 |
+
"entities": entities[:30],
|
| 268 |
+
"contradictions": contradictions,
|
| 269 |
+
},
|
| 270 |
+
}
|
| 271 |
+
except Exception as e:
|
| 272 |
+
print(f"[API] RAG indexing error: {e}")
|
| 273 |
+
|
| 274 |
if user:
|
| 275 |
await supabase_insert("analyses", {
|
| 276 |
"user_id": user["id"],
|
|
|
|
| 286 |
"compliance": compliance,
|
| 287 |
})
|
| 288 |
|
| 289 |
+
return {
|
| 290 |
+
"risk_score": risk,
|
| 291 |
+
"grade": grade,
|
| 292 |
+
"total_clauses": len(clauses),
|
| 293 |
+
"flagged_count": len(set(cr["text"] for cr in clause_results)),
|
| 294 |
+
"results": results_for_db,
|
| 295 |
+
"entities": entities,
|
| 296 |
+
"contradictions": contradictions,
|
| 297 |
+
"obligations": obligations,
|
| 298 |
+
"compliance": compliance,
|
| 299 |
+
"redlines": redlines,
|
| 300 |
+
"model": "ml" if cuad_model else "regex",
|
| 301 |
+
"latency_ms": latency,
|
| 302 |
+
"session_id": session_id,
|
| 303 |
+
}
|
| 304 |
|
| 305 |
@app.post("/api/compare")
|
| 306 |
async def compare(req: CompareRequest, request: Request):
|
| 307 |
client_ip = request.client.host if request.client else "unknown"
|
| 308 |
if not _check_rate_limit(client_ip):
|
| 309 |
raise HTTPException(status_code=429, detail="Rate limit exceeded.")
|
| 310 |
+
return compare_contracts(req.text_a, req.text_b)
|
| 311 |
+
|
| 312 |
+
+
+@app.post("/api/redline")
+async def redline(req: RedlineRequest, request: Request):
+    client_ip = request.client.host if request.client else "unknown"
+    if not _check_rate_limit(client_ip):
+        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+
+    if req.session_id and req.session_id in _rag_sessions:
+        analysis = _rag_sessions[req.session_id]["analysis"]
+    elif req.text:
+        result, error = analyze_contract(req.text)
+        if error:
+            raise HTTPException(status_code=400, detail=error)
+        analysis = result
+    else:
+        raise HTTPException(status_code=400, detail="Provide session_id or text")
+
+    redlines = generate_redlines(analysis, use_llm=req.use_llm)
+    return {"redlines": redlines, "count": len(redlines)}
+
+@app.post("/api/chat")
+async def chat(req: ChatRequest, request: Request):
+    client_ip = request.client.host if request.client else "unknown"
+    if not _check_rate_limit(client_ip):
+        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+
+    if req.session_id not in _rag_sessions:
+        raise HTTPException(status_code=404, detail="Session not found. Analyze a contract first.")
+
+    session = _rag_sessions[req.session_id]
+    response_text = ""
+    for partial in chat_respond(req.message, req.history or [],
+                                session["chunks"], session["embeddings"], session["analysis"]):
+        response_text = partial
+
+    return {"response": response_text, "session_id": req.session_id}
+
+@app.post("/api/chat/stream")
+async def chat_stream(req: ChatRequest, request: Request):
+    client_ip = request.client.host if request.client else "unknown"
+    if not _check_rate_limit(client_ip):
+        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+
+    if req.session_id not in _rag_sessions:
+        raise HTTPException(status_code=404, detail="Session not found.")
+
+    session = _rag_sessions[req.session_id]
+
+    async def generate():
+        last = ""
+        for partial in chat_respond(
+            req.message, req.history or [],
+            session["chunks"], session["embeddings"], session["analysis"]
+        ):
+            delta = partial[len(last):]
+            last = partial
+            if delta:
+                yield f"data: {json.dumps({'delta': delta})}\n\n"
+        yield "data: [DONE]\n\n"
+
+    return StreamingResponse(generate(), media_type="text/event-stream")
+
+@app.post("/api/ocr")
+async def ocr_endpoint(file: UploadFile = FastAPIFile(...)):
+    if not file.filename or not file.filename.lower().endswith(".pdf"):
+        raise HTTPException(status_code=400, detail="Only PDF files supported")
+
+    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
+        content = await file.read()
+        tmp.write(content)
+        tmp_path = tmp.name
+
+    try:
+        text, error, method = parse_pdf_smart(tmp_path)
+        if error:
+            raise HTTPException(status_code=400, detail=error)
+        return {"text": text, "method": method, "chars": len(text) if text else 0, "filename": file.filename}
+    finally:
+        os.unlink(tmp_path)
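The `/api/chat/stream` endpoint above converts the generator's *cumulative* partials into SSE *deltas*: each event carries only the new suffix, so a client rebuilds the answer by concatenating deltas until `[DONE]`. A runnable sketch of that conversion, with `chat_respond` simulated by a hard-coded partial sequence:

```python
# Delta computation as used by /api/chat/stream: cumulative partials in,
# suffix-only SSE events out. fake_chat_respond stands in for chat_respond.
import json

def fake_chat_respond():
    yield "The contract"
    yield "The contract limits"
    yield "The contract limits liability."

def sse_events(partials):
    last = ""
    for partial in partials:
        delta = partial[len(last):]  # only the newly generated suffix
        last = partial
        if delta:
            yield f"data: {json.dumps({'delta': delta})}\n\n"
    yield "data: [DONE]\n\n"

events = list(sse_events(fake_chat_respond()))
deltas = [json.loads(e[len("data: "):].strip())["delta"] for e in events[:-1]]
assert "".join(deltas) == "The contract limits liability."
assert events[-1] == "data: [DONE]\n\n"
```

Sending deltas instead of re-sending the growing partial keeps each SSE frame O(new tokens) rather than O(answer length), which matters for long answers.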
 @app.post("/api/explain", response_model=ExplainResponse)
 async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
     desc = DESC_MAP.get(req.category, "Unknown category.")
     legal = "Consult local consumer protection laws."
+    recommendation = "Review this clause carefully."

     if SAULLM_ENDPOINT and HF_API_TOKEN:
         try:
             prompt = (
+                f"Analyze this contract clause and explain why it may be risky.\n\n"
+                f"Clause: \"{req.clause}\"\nCategory: {req.category}\n\n"
+                f"Provide: 1) Plain-English explanation 2) Legal basis 3) Recommendation"
             )
             async with httpx.AsyncClient(timeout=30.0) as client:
                 resp = await client.post(
…
         except Exception:
             pass

+    return ExplainResponse(clause=req.clause, category=req.category,
+                           explanation=desc, legal_basis=legal, recommendation=recommendation)

 @app.get("/api/history")
 async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
     limit = min(limit, 100)
+    data = await supabase_query("analyses", {
+        "user_id": f"eq.{user['id']}", "select": "*",
+        "order": "created_at.desc", "limit": str(limit), "offset": str(offset),
+    })
     return {"analyses": data, "limit": limit, "offset": offset}

 if __name__ == "__main__":
api/requirements.txt
CHANGED
@@ -1,10 +1,13 @@
-fastapi>=0.…
-uvicorn[standard]>=0.…
-pydantic>=2.…
-transformers>=…
+fastapi>=0.115.0
+uvicorn[standard]>=0.34.0
+pydantic>=2.10.0
+transformers>=4.45.0
 numpy>=2.0.0
 python-jose[cryptography]>=3.3.0
 httpx>=0.28.0
 peft>=0.15.0
 torch>=2.5.0
 sentence-transformers>=3.0.0
+python-doctr[torch]>=0.9.0
+huggingface_hub>=0.25.0
+python-multipart>=0.0.7
app.py
CHANGED
@@ -1,7 +1,12 @@
 """
-ClauseGuard — World's Best Legal Contract Analysis Tool (…
 ═══════════════════════════════════════════════════════════════
-…
 • Fixed CUAD label mapping (added missing index 6: "Notice Period to Terminate Renewal")
 • Switched from softmax → sigmoid for proper multi-label classification
 • Per-class optimized thresholds instead of flat 0.15

@@ -21,6 +26,9 @@ Models:
     (LoRA adapter on nlpaueb/legal-bert-base-uncased, 41 CUAD classes)
 • Legal NER: matterstack/legal-bert-ner (token classification)
 • NLI: cross-encoder/nli-deberta-v3-base (contradiction detection)
 """

 import os

@@ -71,6 +79,9 @@ except Exception:
 from compare import compare_contracts, render_comparison_html
 from obligations import extract_obligations, render_obligations_html
 from compliance import check_compliance, render_compliance_html

 # ═══════════════════════════════════════════════════════════════════════
 # 1. CONFIGURATION — FIXED label mapping (41 labels, index 6 restored)

@@ -335,20 +346,15 @@ _load_nli_model()
 # ═══════════════════════════════════════════════════════════════════════

 def parse_pdf(file_path):
-    …
-    if not text.strip():
-        return None, "PDF appears to be scanned/image-based. OCR is not yet supported. Please use a digital PDF or paste text directly."
-    return text.strip(), None
-    except Exception as e:
-        return None, f"PDF parse error: {e}"

 def parse_docx(file_path):
     if not _HAS_DOCX:
@@ -378,11 +384,22 @@ def parse_document(file_path):
     return None, f"Unsupported file type: {ext}"

 # ═══════════════════════════════════════════════════════════════════════
-# 4. …
 # ═══════════════════════════════════════════════════════════════════════

 def split_clauses(text):
-    """…
     text = re.sub(r'\n{3,}', '\n\n', text.strip())

     # First try to detect numbered sections (1., 2., 3.1, (a), etc.)

@@ -426,9 +443,13 @@ split_clauses(text)
     preamble = text[:positions[0]].strip()
     if len(preamble) > 30:
         clauses.insert(0, preamble)
-    …
     else:
-    …

 def _fallback_split(text):
     """Fallback: split on paragraph breaks and sentence boundaries."""

@@ -462,8 +483,40 @@ _fallback_split(text)

 # ═══════════════════════════════════════════════════════════════════════
 # 5. CLAUSE DETECTION — FIXED: sigmoid + per-class thresholds + caching
 # ═══════════════════════════════════════════════════════════════════════

 def _text_hash(text):
     return hashlib.md5(text.encode()).hexdigest()

@@ -474,14 +527,17 @@ classify_cuad(clause_text)
     if cuad_model is None or cuad_tokenizer is None:
         return _classify_regex(clause_text)

     # Check cache
-    h = _text_hash(…
     if h in _prediction_cache:
         return _prediction_cache[h]

     try:
         inputs = cuad_tokenizer(
-            …
             return_tensors="pt",
             truncation=True,
             max_length=256,

@@ -498,10 +554,15 @@ classify_cuad(clause_text)
             threshold = _CUAD_THRESHOLDS.get(i, 0.40)
             if float(prob) > threshold and i < len(CUAD_LABELS):
                 label = CUAD_LABELS[i]
                 risk = RISK_MAP.get(label, "LOW")
                 results.append({
                     "label": label,
-                    "confidence": round(…
                     "risk": risk,
                     "description": DESC_MAP.get(label, label),
                     "source": "ml",
@@ -773,19 +834,33 @@ def detect_contradictions(clause_results, raw_text=""):
|
|
| 773 |
"source": "heuristic",
|
| 774 |
})
|
| 775 |
|
| 776 |
-
# ── 2. Missing critical clauses ──
|
| 777 |
-
|
| 778 |
-
"Governing Law":
|
| 779 |
-
|
| 780 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
}
|
| 782 |
-
for
|
| 783 |
-
|
|
|
|
| 784 |
contradictions.append({
|
| 785 |
"type": "MISSING",
|
| 786 |
-
"explanation":
|
| 787 |
"severity": "MEDIUM",
|
| 788 |
-
"clauses": [
|
| 789 |
"source": "structural",
|
| 790 |
})
|
| 791 |
|
|
@@ -847,13 +922,21 @@ def analyze_contract(text):
|
|
| 847 |
contradictions = detect_contradictions(clause_results, text)
|
| 848 |
risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
|
| 849 |
obligations = extract_obligations(text)
|
|
|
|
| 850 |
compliance = check_compliance(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
result = {
|
| 852 |
"metadata": {
|
| 853 |
"analysis_date": datetime.now().isoformat(),
|
| 854 |
"total_clauses": len(clauses),
|
| 855 |
-
"flagged_clauses":
|
|
|
|
| 856 |
"model": get_model_status_text(),
|
|
|
|
| 857 |
},
|
| 858 |
"risk": {
|
| 859 |
"score": risk,
|
|
@@ -1119,11 +1202,11 @@ def process_upload(file):
|
|
| 1119 |
def run_analysis(text):
|
| 1120 |
if not text or len(text.strip()) < 50:
|
| 1121 |
err_html = '<p style="color:#dc2626;padding:16px;">Document too short (minimum 50 characters)</p>'
|
| 1122 |
-
return [err_html] *
|
| 1123 |
result, error = analyze_contract(text)
|
| 1124 |
if error:
|
| 1125 |
err_html = f'<p style="color:#dc2626;padding:16px;">{error}</p>'
|
| 1126 |
-
return [err_html] *
|
| 1127 |
|
| 1128 |
# FIXED: per-session temp files
|
| 1129 |
session_id = uuid.uuid4().hex[:8]
|
|
@@ -1136,6 +1219,10 @@ def run_analysis(text):
|
|
| 1136 |
with open(csv_path, "w") as f:
|
| 1137 |
f.write(csv_content)
|
| 1138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1139 |
return [
|
| 1140 |
render_summary(result),
|
| 1141 |
render_clause_cards(result),
|
|
@@ -1144,13 +1231,15 @@ def run_analysis(text):
|
|
| 1144 |
render_document_viewer(result),
|
| 1145 |
render_obligations_html(result.get("obligations", [])),
|
| 1146 |
render_compliance_html(result.get("compliance", {})),
|
|
|
|
| 1147 |
json_path,
|
| 1148 |
csv_path,
|
| 1149 |
"Analysis complete",
|
|
|
|
| 1150 |
]
|
| 1151 |
|
| 1152 |
def do_clear():
|
| 1153 |
-
return [""] *
|
| 1154 |
|
| 1155 |
# ── Example contracts ──
|
| 1156 |
SPOTIFY_TOS = """By using the Spotify Service, you agree to be bound by these Terms of Use.
|
|
@@ -1234,17 +1323,22 @@ with gr.Blocks(
|
|
| 1234 |
"""
|
| 1235 |
) as demo:
|
| 1236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1237 |
gr.HTML("""
|
| 1238 |
<div style="display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:2px solid #e5e7eb;margin-bottom:16px;">
|
| 1239 |
<div>
|
| 1240 |
<h1 style="font-size:24px;font-weight:700;margin:0;color:#1f2937;">🛡️ ClauseGuard</h1>
|
| 1241 |
-
<p style="font-size:13px;color:#6b7280;margin:4px 0 0 0;">AI-Powered Legal Contract Analysis · 41 Clause Categories · Risk Scoring · ML NER · NLI Contradictions · Compliance · Obligations</p>
|
| 1242 |
</div>
|
| 1243 |
-
<div style="font-size:12px;color:#9ca3af;">
|
| 1244 |
</div>
|
| 1245 |
""")
|
| 1246 |
|
| 1247 |
-
# ── Main Tabs: Analysis vs Comparison ──
|
| 1248 |
with gr.Tabs():
|
| 1249 |
|
| 1250 |
# ═══════ TAB 1: Single Contract Analysis ═══════
|
|
@@ -1261,7 +1355,7 @@ with gr.Blocks(
|
|
| 1261 |
with gr.Column(scale=3):
|
| 1262 |
text_input = gr.Textbox(
|
| 1263 |
label="📄 Contract Text",
|
| 1264 |
-
placeholder="Paste contract text here, or upload a file above...",
|
| 1265 |
lines=14,
|
| 1266 |
max_lines=40,
|
| 1267 |
show_copy_button=True,
|
|
@@ -1304,6 +1398,8 @@ with gr.Blocks(
|
|
| 1304 |
obligations_html = gr.HTML(label="Obligation Tracker")
|
| 1305 |
with gr.Tab("⚖️ Compliance"):
|
| 1306 |
compliance_html = gr.HTML(label="Compliance Checker")
|
|
|
|
|
|
|
| 1307 |
|
| 1308 |
# ═══════ TAB 2: Contract Comparison ═══════
|
| 1309 |
with gr.Tab("🔀 Compare Contracts"):
|
|
@@ -1352,6 +1448,53 @@ with gr.Blocks(
|
|
| 1352 |
with gr.Column(scale=2):
|
| 1353 |
comp_json = gr.JSON(label="Raw Comparison Data")
|
| 1354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1355 |
# ── Events ──
|
| 1356 |
def _load_file(file):
|
| 1357 |
text, err = parse_document(file) if file else ("", "No file")
|
|
@@ -1359,23 +1502,41 @@ with gr.Blocks(
|
|
| 1359 |
return "", err
|
| 1360 |
return text, "Loaded successfully" if not err else err
|
| 1361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1362 |
load_btn.click(_load_file, inputs=[file_input], outputs=[text_input, load_status])
|
| 1363 |
comp_load_a.click(_load_file, inputs=[comp_file_a], outputs=[comp_text_a, comp_status_a])
|
| 1364 |
comp_load_b.click(_load_file, inputs=[comp_file_b], outputs=[comp_text_b, comp_status_b])
|
| 1365 |
|
| 1366 |
scan_btn.click(
|
| 1367 |
-
|
| 1368 |
inputs=[text_input],
|
| 1369 |
-
outputs=[
|
| 1370 |
-
|
| 1371 |
-
|
|
|
|
|
|
|
|
|
|
| 1372 |
)
|
| 1373 |
|
| 1374 |
clear_btn.click(
|
| 1375 |
-
|
| 1376 |
-
outputs=[
|
| 1377 |
-
|
| 1378 |
-
|
|
|
|
|
|
|
|
|
|
| 1379 |
)
|
| 1380 |
|
| 1381 |
comp_btn.click(
|
|
@@ -1391,6 +1552,8 @@ with gr.Blocks(
|
|
| 1391 |
· Model: <a href="https://huggingface.co/Mokshith31/legalbert-contract-clause-classification" style="color:#6b7280;">Legal-BERT + CUAD (41 classes)</a>
|
| 1392 |
· NER: <a href="https://huggingface.co/matterstack/legal-bert-ner" style="color:#6b7280;">Legal-BERT NER</a>
|
| 1393 |
· NLI: <a href="https://huggingface.co/cross-encoder/nli-deberta-v3-base" style="color:#6b7280;">DeBERTa-v3 NLI</a>
|
|
|
|
|
|
|
| 1394 |
· Dataset: <a href="https://huggingface.co/datasets/theatticusproject/cuad-qa" style="color:#6b7280;">CUAD</a>
|
| 1395 |
· <a href="https://huggingface.co/spaces/gaurv007/ClauseGuard" style="color:#6b7280;">ClauseGuard Space</a>
|
| 1396 |
</p>
|
|
|
|
| 1 |
 """
+ClauseGuard — World's Best Legal Contract Analysis Tool (v4.0)
 ═══════════════════════════════════════════════════════════════
+New in v4.0:
+  • OCR support for scanned PDFs (docTR engine with smart native/scanned routing)
+  • Contract Q&A Chatbot (RAG: embedding retrieval + HF Inference API streaming)
+  • Clause Redlining (3-tier: template lookup + RAG + LLM refinement)
+
+Carried from v3.0:
 • Fixed CUAD label mapping (added missing index 6: "Notice Period to Terminate Renewal")
 • Switched from softmax → sigmoid for proper multi-label classification
 • Per-class optimized thresholds instead of flat 0.15

     (LoRA adapter on nlpaueb/legal-bert-base-uncased, 41 CUAD classes)
 • Legal NER: matterstack/legal-bert-ner (token classification)
 • NLI: cross-encoder/nli-deberta-v3-base (contradiction detection)
+• Embeddings: sentence-transformers/all-MiniLM-L6-v2 (RAG retrieval)
+• OCR: docTR fast_base + crnn_vgg16_bn (scanned PDF extraction)
+• LLM: Qwen/Qwen2.5-7B-Instruct via HF Inference API (chatbot + redlining)
 """

 import os

 from compare import compare_contracts, render_comparison_html
 from obligations import extract_obligations, render_obligations_html
 from compliance import check_compliance, render_compliance_html
+from ocr_engine import parse_pdf_smart, get_ocr_status
+from chatbot import index_contract, chat_respond, get_chatbot_status
+from redlining import generate_redlines, render_redlines_html

 # ═══════════════════════════════════════════════════════════════════════
 # 1. CONFIGURATION — FIXED label mapping (41 labels, index 6 restored)

 # ═══════════════════════════════════════════════════════════════════════

 def parse_pdf(file_path):
+    """Smart PDF parser: native text extraction with OCR fallback for scanned PDFs."""
+    text, error, method = parse_pdf_smart(file_path)
+    if text:
+        if method == "ocr":
+            print(f"[ClauseGuard] PDF extracted via OCR ({len(text)} chars)")
+        return text, None
+    if error:
+        return None, error
+    return None, "Could not extract text from PDF. Try uploading a clearer scan or digital PDF."

 def parse_docx(file_path):
     if not _HAS_DOCX:

     return None, f"Unsupported file type: {ext}"

 # ═══════════════════════════════════════════════════════════════════════
+# 4. DETERMINISTIC CLAUSE SPLITTING (Fix 1 from bug report)
 # ═══════════════════════════════════════════════════════════════════════

+# Document-level chunk cache: same text always produces same chunks
+_chunk_cache = {}
+
 def split_clauses(text):
+    """Deterministic, structure-aware clause splitting.
+    Fix 1: Same input ALWAYS produces same output. Normalized text is hashed
+    and cached so repeated runs on identical documents are identical."""
+    # Normalize whitespace before hashing for determinism
+    normalized = re.sub(r'\s+', ' ', text.strip())
+    text_hash = hashlib.sha256(normalized.encode()).hexdigest()
+    if text_hash in _chunk_cache:
+        return _chunk_cache[text_hash]
+
     text = re.sub(r'\n{3,}', '\n\n', text.strip())

     # First try to detect numbered sections (1., 2., 3.1, (a), etc.)

         preamble = text[:positions[0]].strip()
         if len(preamble) > 30:
             clauses.insert(0, preamble)
+        result = clauses if clauses else _fallback_split(text)
+        _chunk_cache[text_hash] = result
+        return result
     else:
+        result = _fallback_split(text)
+        _chunk_cache[text_hash] = result
+        return result
 def _fallback_split(text):
     """Fallback: split on paragraph breaks and sentence boundaries."""

 # ═══════════════════════════════════════════════════════════════════════
 # 5. CLAUSE DETECTION — FIXED: sigmoid + per-class thresholds + caching
+# Fix 3: Strip section headings before classification
+# Fix 6: Label guardrails for high-confidence false positives
 # ═══════════════════════════════════════════════════════════════════════

+# Fix 3: Section heading pattern — strip before classifying
+_HEADING_RE = re.compile(r'^\d+(?:\.\d+)*\s+[A-Z][A-Z\s&,/]+$', re.MULTILINE)
+
+def _strip_heading(text):
+    """Remove leading section headings that confuse the classifier."""
+    lines = text.split('\n')
+    if lines and _HEADING_RE.match(lines[0].strip()):
+        stripped = '\n'.join(lines[1:]).strip()
+        return stripped if len(stripped) > 20 else text
+    return text
+
+# Fix 6: Label guardrails — keyword validation for high-confidence labels
+_LABEL_GUARDRAILS = {
+    "Liquidated Damages": re.compile(
+        r'liquidated|pre-?determined.{0,10}damage|agreed.{0,10}sum|penalty clause|stipulated.{0,10}damage',
+        re.IGNORECASE
+    ),
+    "Uncapped Liability": re.compile(
+        r'uncapped|unlimited.{0,10}liabilit|no.{0,10}(limit|cap).{0,10}liabilit',
+        re.IGNORECASE
+    ),
+}
+
+def _apply_guardrails(label, text, confidence):
+    """Fix 6: If label has a guardrail and text lacks required keywords, demote."""
+    guard = _LABEL_GUARDRAILS.get(label)
+    if guard and not guard.search(text):
+        return "Other", confidence * 0.3  # demote to Other with reduced confidence
+    return label, confidence
+
 def _text_hash(text):
     return hashlib.md5(text.encode()).hexdigest()

     if cuad_model is None or cuad_tokenizer is None:
         return _classify_regex(clause_text)

+    # Fix 3: Strip section headings before classification
+    clean_text = _strip_heading(clause_text)
+
     # Check cache
+    h = _text_hash(clean_text[:512])
     if h in _prediction_cache:
         return _prediction_cache[h]

     try:
         inputs = cuad_tokenizer(
+            clean_text,
             return_tensors="pt",
             truncation=True,
             max_length=256,

             threshold = _CUAD_THRESHOLDS.get(i, 0.40)
             if float(prob) > threshold and i < len(CUAD_LABELS):
                 label = CUAD_LABELS[i]
+                conf = float(prob)
+                # Fix 6: Apply guardrails — reject high-confidence false positives
+                label, conf = _apply_guardrails(label, clause_text, conf)
+                if label == "Other" and conf < 0.3:
+                    continue  # Skip demoted labels
                 risk = RISK_MAP.get(label, "LOW")
                 results.append({
                     "label": label,
+                    "confidence": round(conf, 3),
                     "risk": risk,
                     "description": DESC_MAP.get(label, label),
                     "source": "ml",

             "source": "heuristic",
         })

+    # ── 2. Missing critical clauses (Fix 4: check raw_text, not labels) ──
+    _REQUIRED_CLAUSE_PATTERNS = {
+        "Governing Law": re.compile(
+            r'govern(?:ed|ing).{0,15}law|applicable.{0,10}law|laws?\s+of\s+the\s+state',
+            re.IGNORECASE
+        ),
+        "Limitation of liability": re.compile(
+            r'limitation.{0,10}liabilit|cap.{0,10}liabilit|liabilit.{0,10}shall\s+not\s+exceed|in\s+no\s+event.{0,20}liable',
+            re.IGNORECASE
+        ),
+        "Arbitration": re.compile(
+            r'arbitrat|AAA|JAMS|binding.{0,10}dispute',
+            re.IGNORECASE
+        ),
+        "Termination": re.compile(
+            r'terminat(?:e|ion|ed)|cancel(?:lation)?',
+            re.IGNORECASE
+        ),
+    }
+    for clause_name, pattern in _REQUIRED_CLAUSE_PATTERNS.items():
+        # Check raw_text directly — it's stable and deterministic
+        if not pattern.search(raw_text):
             contradictions.append({
                 "type": "MISSING",
+                "explanation": f"No '{clause_name}' clause detected in the document.",
                 "severity": "MEDIUM",
+                "clauses": [clause_name],
                 "source": "structural",
             })
     contradictions = detect_contradictions(clause_results, text)
     risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
     obligations = extract_obligations(text)
+    # Fix 5: Compliance runs against full raw_text (already done in compliance.py)
     compliance = check_compliance(text)
+
+    # Fix 2: Compute flagged_clauses AFTER all processing is complete
+    flagged_clause_count = len(clause_results)
+    unique_flagged_texts = len(set(cr["text"] for cr in clause_results))
+
     result = {
         "metadata": {
             "analysis_date": datetime.now().isoformat(),
             "total_clauses": len(clauses),
+            "flagged_clauses": flagged_clause_count,
+            "unique_flagged": unique_flagged_texts,
             "model": get_model_status_text(),
+            "text_hash": hashlib.sha256(re.sub(r'\s+', ' ', text.strip()).encode()).hexdigest()[:16],
         },
         "risk": {
             "score": risk,

 def run_analysis(text):
     if not text or len(text.strip()) < 50:
         err_html = '<p style="color:#dc2626;padding:16px;">Document too short (minimum 50 characters)</p>'
+        return [err_html] * 8 + [None, None, "", None]
     result, error = analyze_contract(text)
     if error:
         err_html = f'<p style="color:#dc2626;padding:16px;">{error}</p>'
+        return [err_html] * 8 + [None, None, error, None]

     # FIXED: per-session temp files
     session_id = uuid.uuid4().hex[:8]

     with open(csv_path, "w") as f:
         f.write(csv_content)

+    # Generate redline suggestions (Tier 1 template + Tier 3 LLM for critical/high)
+    redlines = generate_redlines(result, use_llm=True)
+    redlines_html = render_redlines_html(redlines)
+
     return [
         render_summary(result),
         render_clause_cards(result),

         render_document_viewer(result),
         render_obligations_html(result.get("obligations", [])),
         render_compliance_html(result.get("compliance", {})),
+        redlines_html,
         json_path,
         csv_path,
         "Analysis complete",
+        result,  # Store analysis result for chatbot
     ]

 def do_clear():
+    return [""] * 8 + [None, None, "", None]

 # ── Example contracts ──
 SPOTIFY_TOS = """By using the Spotify Service, you agree to be bound by these Terms of Use.
"""
|
| 1324 |
) as demo:
|
| 1325 |
|
| 1326 |
+
# ── Shared State (for chatbot RAG) ──────────────────────────────
|
| 1327 |
+
analysis_state = gr.State(None) # Full analysis result dict
|
| 1328 |
+
chunks_state = gr.State([]) # Contract text chunks for RAG
|
| 1329 |
+
embeddings_state = gr.State(None) # Chunk embeddings (numpy array)
|
| 1330 |
+
|
| 1331 |
gr.HTML("""
|
| 1332 |
<div style="display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:2px solid #e5e7eb;margin-bottom:16px;">
|
| 1333 |
<div>
|
| 1334 |
<h1 style="font-size:24px;font-weight:700;margin:0;color:#1f2937;">🛡️ ClauseGuard</h1>
|
| 1335 |
+
<p style="font-size:13px;color:#6b7280;margin:4px 0 0 0;">AI-Powered Legal Contract Analysis · 41 Clause Categories · Risk Scoring · ML NER · NLI Contradictions · Compliance · Obligations · <strong>Q&A Chatbot</strong> · <strong>Clause Redlining</strong> · <strong>OCR</strong></p>
|
| 1336 |
</div>
|
| 1337 |
+
<div style="font-size:12px;color:#9ca3af;">v4.0 · Precision Legal AI</div>
|
| 1338 |
</div>
|
| 1339 |
""")
|
| 1340 |
|
| 1341 |
+
# ── Main Tabs: Analysis vs Comparison vs Chatbot ──
|
| 1342 |
with gr.Tabs():
|
| 1343 |
|
| 1344 |
# ═══════ TAB 1: Single Contract Analysis ═══════
|
|
|
|
| 1355 |
with gr.Column(scale=3):
|
| 1356 |
text_input = gr.Textbox(
|
| 1357 |
label="📄 Contract Text",
|
| 1358 |
+
placeholder="Paste contract text here, or upload a file above...\n\n💡 Scanned PDFs are automatically processed with OCR.",
|
| 1359 |
lines=14,
|
| 1360 |
max_lines=40,
|
| 1361 |
show_copy_button=True,
|
|
|
|
| 1398 |
obligations_html = gr.HTML(label="Obligation Tracker")
|
| 1399 |
with gr.Tab("⚖️ Compliance"):
|
| 1400 |
compliance_html = gr.HTML(label="Compliance Checker")
|
| 1401 |
+
with gr.Tab("✏️ Redlining"):
|
| 1402 |
+
redlining_html = gr.HTML(label="Clause Redlining Suggestions")
|
| 1403 |
|
| 1404 |
# ═══════ TAB 2: Contract Comparison ═══════
|
| 1405 |
with gr.Tab("🔀 Compare Contracts"):
|
|
|
|
| 1448 |
with gr.Column(scale=2):
|
| 1449 |
comp_json = gr.JSON(label="Raw Comparison Data")
|
| 1450 |
    # ═══════ TAB 3: Contract Q&A Chatbot ═══════
    with gr.Tab("💬 Contract Q&A"):
        gr.HTML("""
        <div style="padding:12px 16px;background:linear-gradient(135deg,#eff6ff,#faf5ff);border-radius:10px;margin-bottom:12px;border:1px solid #e5e7eb;">
          <div style="display:flex;align-items:center;gap:8px;margin-bottom:6px;">
            <span style="font-size:20px;">💬</span>
            <h3 style="margin:0;font-size:16px;color:#1f2937;">Contract Q&A Chatbot</h3>
          </div>
          <p style="font-size:12px;color:#6b7280;margin:0;line-height:1.5;">
            Ask questions about your analyzed contract. The chatbot uses <strong>RAG</strong> (Retrieval-Augmented Generation)
            to find relevant clauses and generate answers grounded in your contract text.
            <br>
            <strong>Step 1:</strong> Analyze a contract in the "📄 Single Contract Analysis" tab.
            <strong>Step 2:</strong> Come back here and ask questions!
          </p>
        </div>
        """)

        chatbot_index_status = gr.Textbox(
            label="📡 Chatbot Index Status",
            interactive=False,
            lines=1,
            value="⏳ No contract indexed yet — analyze a contract first",
        )

        def _chatbot_fn(message, history, chunks, embeddings, analysis):
            """Wrapper matching the gr.ChatInterface fn signature."""
            yield from chat_respond(message, history, chunks, embeddings, analysis)

        gr.ChatInterface(
            fn=_chatbot_fn,
            type="messages",
            additional_inputs=[chunks_state, embeddings_state, analysis_state],
            examples=[
                ["What are the main risks in this contract?"],
                ["Who are the parties involved?"],
                ["What happens if the contract is terminated?"],
                ["Are there any liability limitations?"],
                ["What are my obligations under this contract?"],
                ["Is there an arbitration clause?"],
                ["What is the governing law?"],
                ["Summarize the key terms in plain language."],
            ],
            title="",
            description="",
        )
    # ── Events ──
    def _load_file(file):
        text, err = parse_document(file) if file else ("", "No file")
        if err and not text:
            return "", err
        return text, "Loaded successfully" if not err else err
    def _analysis_and_index(text):
        """Run the standard analysis AND index the text for the chatbot in one call."""
        # Run the standard analysis
        analysis_outputs = run_analysis(text)

        # Index for the chatbot (uses the raw text)
        chunks, embeddings, index_status = index_contract(text)

        # analysis_outputs has 12 items: 8 HTML panels + json_path + csv_path + status + result
        # Append the chatbot outputs: chunks_state, embeddings_state, chatbot_index_status
        return analysis_outputs + [chunks, embeddings, index_status]
    load_btn.click(_load_file, inputs=[file_input], outputs=[text_input, load_status])
    comp_load_a.click(_load_file, inputs=[comp_file_a], outputs=[comp_text_a, comp_status_a])
    comp_load_b.click(_load_file, inputs=[comp_file_b], outputs=[comp_text_b, comp_status_b])

    scan_btn.click(
        _analysis_and_index,
        inputs=[text_input],
        outputs=[
            summary_html, clauses_html, entities_html, nli_html,
            doc_html, obligations_html, compliance_html, redlining_html,
            json_file, csv_file, status_msg, analysis_state,
            chunks_state, embeddings_state, chatbot_index_status,
        ],
    )

    clear_btn.click(
        lambda: [""] * 8 + [None, None, "", None, [], None, "⏳ No contract indexed"],
        outputs=[
            summary_html, clauses_html, entities_html, nli_html,
            doc_html, obligations_html, compliance_html, redlining_html,
            json_file, csv_file, status_msg, analysis_state,
            chunks_state, embeddings_state, chatbot_index_status,
        ],
    )

    comp_btn.click(
|
|
|
        · Model: <a href="https://huggingface.co/Mokshith31/legalbert-contract-clause-classification" style="color:#6b7280;">Legal-BERT + CUAD (41 classes)</a>
        · NER: <a href="https://huggingface.co/matterstack/legal-bert-ner" style="color:#6b7280;">Legal-BERT NER</a>
        · NLI: <a href="https://huggingface.co/cross-encoder/nli-deberta-v3-base" style="color:#6b7280;">DeBERTa-v3 NLI</a>
        · LLM: <a href="https://huggingface.co/Qwen/Qwen2.5-7B-Instruct" style="color:#6b7280;">Qwen2.5-7B</a>
        · OCR: <a href="https://github.com/mindee/doctr" style="color:#6b7280;">docTR</a>
        · Dataset: <a href="https://huggingface.co/datasets/theatticusproject/cuad-qa" style="color:#6b7280;">CUAD</a>
        · <a href="https://huggingface.co/spaces/gaurv007/ClauseGuard" style="color:#6b7280;">ClauseGuard Space</a>
    </p>
chatbot.py
ADDED
|
@@ -0,0 +1,406 @@
"""
ClauseGuard — Contract Q&A Chatbot (RAG) v1.0
═══════════════════════════════════════════════
Architecture:
    User asks a question about their contract
        ↓
    [1] Embed the question with sentence-transformers (all-MiniLM-L6-v2)
        ↓
    [2] Retrieve the top-5 most relevant chunks from the contract
        ↓
    [3] Build the prompt:
        - System: ClauseGuard analysis results (clauses, entities, risk scores)
        - Context: retrieved contract chunks (≤2.5K tokens)
        - User question
        ↓
    [4] Stream the response from the LLM via the HF Inference API

Key design:
  • Analyzed data (clauses, entities, risk scores) → system prompt
  • Raw contract text → RAG retrieval
  • This gives the model both structured analysis AND verbatim evidence
"""
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
# ── Embedding model (soft-fail) ─────────────────────────────────────
|
| 29 |
+
_HAS_EMBEDDER = False
|
| 30 |
+
_embedder = None
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from sentence_transformers import SentenceTransformer
|
| 34 |
+
_HAS_EMBEDDER = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
# ── HF Inference Client (soft-fail) ─────────────────────────────────
|
| 39 |
+
_HAS_INFERENCE = False
|
| 40 |
+
_llm_client = None
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from huggingface_hub import InferenceClient
|
| 44 |
+
_HAS_INFERENCE = True
|
| 45 |
+
except ImportError:
|
| 46 |
+
pass
|
| 47 |
+
|
# ═══════════════════════════════════════════════════════════════════════
# MODEL LOADING
# ═══════════════════════════════════════════════════════════════════════

_chatbot_status = {"embedder": "not_loaded", "llm": "not_loaded"}

def _load_embedder():
    """Load the sentence-transformers embedding model (lazy)."""
    global _embedder, _chatbot_status
    if _embedder is not None:
        return _embedder
    if not _HAS_EMBEDDER:
        _chatbot_status["embedder"] = "unavailable"
        return None
    try:
        print("[ClauseGuard Chat] Loading embedding model: all-MiniLM-L6-v2...")
        _embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        _chatbot_status["embedder"] = "loaded"
        print("[ClauseGuard Chat] Embedding model loaded")
        return _embedder
    except Exception as e:
        _chatbot_status["embedder"] = f"failed: {e}"
        print(f"[ClauseGuard Chat] Embedder load failed: {e}")
        return None


def _get_llm_client():
    """Get or create the HF Inference Client (lazy)."""
    global _llm_client, _chatbot_status
    if _llm_client is not None:
        return _llm_client
    if not _HAS_INFERENCE:
        _chatbot_status["llm"] = "unavailable"
        return None
    try:
        token = os.environ.get("HF_TOKEN", "")
        _llm_client = InferenceClient(
            provider="hf-inference",
            api_key=token if token else None,
        )
        _chatbot_status["llm"] = "loaded"
        print("[ClauseGuard Chat] HF Inference Client initialized")
        return _llm_client
    except Exception as e:
        _chatbot_status["llm"] = f"failed: {e}"
        print(f"[ClauseGuard Chat] LLM client init failed: {e}")
        return None


def get_chatbot_status():
    """Return a human-readable chatbot status string."""
    parts = []
    for name, status in _chatbot_status.items():
        icon = "✅" if status == "loaded" else "⚠️" if "failed" in status else "❌"
        label = {"embedder": "Embeddings", "llm": "LLM API"}[name]
        parts.append(f"{icon} {label}: {status}")
    return " · ".join(parts)
+
# ═══════════════════════════════════════════════════════════════════════
|
| 108 |
+
# TEXT CHUNKING (sentence-preserving, ~300 tokens, no overlap)
|
| 109 |
+
# ═══════════════════════════════════════════════════════════════════════
|
| 110 |
+
|
| 111 |
+
def chunk_contract_text(text, target_chunk_size=300, min_chunk_size=50):
|
| 112 |
+
"""
|
| 113 |
+
Split contract text into chunks for RAG retrieval.
|
| 114 |
+
Sentence-preserving, ~300 tokens per chunk, 0% overlap.
|
| 115 |
+
Research (arxiv 2601.14123): overlap adds cost with zero benefit.
|
| 116 |
+
"""
|
| 117 |
+
if not text:
|
| 118 |
+
return []
|
| 119 |
+
|
| 120 |
+
# First split on paragraph boundaries
|
| 121 |
+
paragraphs = re.split(r'\n\n+', text)
|
| 122 |
+
chunks = []
|
| 123 |
+
current_chunk = ""
|
| 124 |
+
|
| 125 |
+
for para in paragraphs:
|
| 126 |
+
para = para.strip()
|
| 127 |
+
if not para:
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
# Estimate word count (rough token proxy)
|
| 131 |
+
words_current = len(current_chunk.split())
|
| 132 |
+
words_para = len(para.split())
|
| 133 |
+
|
| 134 |
+
if words_current + words_para <= target_chunk_size:
|
| 135 |
+
current_chunk += ("\n\n" + para if current_chunk else para)
|
| 136 |
+
else:
|
| 137 |
+
# Current chunk is full enough — save it
|
| 138 |
+
if words_current >= min_chunk_size:
|
| 139 |
+
chunks.append(current_chunk.strip())
|
| 140 |
+
current_chunk = para
|
| 141 |
+
else:
|
| 142 |
+
# Current chunk too small — need to split the paragraph into sentences
|
| 143 |
+
sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', para)
|
| 144 |
+
for sent in sentences:
|
| 145 |
+
words_current = len(current_chunk.split())
|
| 146 |
+
words_sent = len(sent.split())
|
| 147 |
+
if words_current + words_sent <= target_chunk_size:
|
| 148 |
+
current_chunk += (" " + sent if current_chunk else sent)
|
| 149 |
+
else:
|
| 150 |
+
if words_current >= min_chunk_size:
|
| 151 |
+
chunks.append(current_chunk.strip())
|
| 152 |
+
current_chunk = sent
|
| 153 |
+
|
| 154 |
+
# Don't forget the last chunk
|
| 155 |
+
if current_chunk.strip() and len(current_chunk.split()) >= min_chunk_size:
|
| 156 |
+
chunks.append(current_chunk.strip())
|
| 157 |
+
|
| 158 |
+
return chunks
|
| 159 |
+
|
| 160 |
+
|
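The greedy paragraph-then-sentence budgeting above can be exercised standalone. The compact sketch below mirrors its logic (same word budgets, same fallback to sentence splitting); the function name and toy document are illustrative, not part of the module:

```python
import re

def chunk_text(text, target=300, minimum=50):
    """Greedy, sentence-preserving chunker mirroring chunk_contract_text."""
    chunks, current = [], ""
    for para in re.split(r"\n\n+", text):
        para = para.strip()
        if not para:
            continue
        if len(current.split()) + len(para.split()) <= target:
            current = f"{current}\n\n{para}" if current else para
        else:
            if len(current.split()) >= minimum:
                chunks.append(current)       # chunk full enough — flush it
                current = para
            else:
                # fall back to sentence-level packing
                for sent in re.split(r"(?<=[.!?])\s+(?=[A-Z])", para):
                    if len(current.split()) + len(sent.split()) <= target:
                        current = f"{current} {sent}" if current else sent
                    else:
                        if len(current.split()) >= minimum:
                            chunks.append(current)
                        current = sent
    if len(current.split()) >= minimum:
        chunks.append(current)               # flush the final chunk
    return chunks

# 10 paragraphs of 62 words each → chunks of 4 + 4 + 2 paragraphs
doc = "\n\n".join(f"Clause {i}. " + "word " * 60 for i in range(10))
chunks = chunk_text(doc, target=300, minimum=50)
print(len(chunks), [len(c.split()) for c in chunks])  # → 3 [248, 248, 124]
```

Every emitted chunk stays within the 300-word budget and above the 50-word floor, which keeps the retrieval embeddings from being dominated by fragments.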
# ═══════════════════════════════════════════════════════════════════════
# EMBEDDING & RETRIEVAL
# ═══════════════════════════════════════════════════════════════════════

def build_embeddings(chunks):
    """
    Embed chunks using sentence-transformers.
    Returns a numpy array of shape (N, 384), or None if the embedder is unavailable.
    """
    embedder = _load_embedder()
    if embedder is None or not chunks:
        return None
    try:
        embeddings = embedder.encode(
            chunks,
            normalize_embeddings=True,
            batch_size=32,
            show_progress_bar=False,
        )
        return embeddings  # numpy array (N, 384)
    except Exception as e:
        print(f"[ClauseGuard Chat] Embedding error: {e}")
        return None


def retrieve_chunks(query, chunks, embeddings, top_k=5):
    """
    Retrieve the top-k most relevant chunks for a query.
    Uses cosine similarity (embeddings are L2-normalized → dot product = cosine).
    Context budget: top-5 chunks, ≤2.5K tokens.
    """
    embedder = _load_embedder()
    if embedder is None or embeddings is None or not chunks:
        return []

    try:
        q_emb = embedder.encode([query], normalize_embeddings=True)
        scores = (q_emb @ embeddings.T)[0]
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        total_words = 0
        max_words = 600  # ~2.5K tokens budget

        for idx in top_indices:
            chunk = chunks[idx]
            chunk_words = len(chunk.split())
            if total_words + chunk_words > max_words and results:
                break
            results.append({
                "text": chunk,
                "score": float(scores[idx]),
                "index": int(idx),
            })
            total_words += chunk_words

        return results
    except Exception as e:
        print(f"[ClauseGuard Chat] Retrieval error: {e}")
        return []
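Because the chunk embeddings are L2-normalized, ranking reduces to a single matrix product. A minimal numpy sketch with synthetic vectors (not real MiniLM output) shows why the dot product recovers the nearest chunk:

```python
import numpy as np

rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 384))                    # 8 fake chunk embeddings
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize rows
q = emb[3] + 0.01 * rng.normal(size=384)           # query vector close to chunk 3
q /= np.linalg.norm(q)
scores = emb @ q                                   # dot product == cosine similarity
top_k = np.argsort(scores)[::-1][:5]               # indices of the top-5 chunks
print(top_k[0])                                    # chunk 3 ranks first
```

Random directions in 384 dimensions are nearly orthogonal, so the perturbed copy of chunk 3 scores close to 1 while the rest score near 0.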
# ═══════════════════════════════════════════════════════════════════════
# SYSTEM PROMPT BUILDER
# ═══════════════════════════════════════════════════════════════════════

def _build_system_prompt(analysis_result, retrieved_chunks):
    """
    Build the system prompt from:
    1. ClauseGuard analysis results (clauses, entities, risk scores) — NOT through RAG
    2. Retrieved contract chunks — through RAG
    """
    parts = []

    parts.append("""You are ClauseGuard AI, a legal contract analysis assistant. You help users understand their contracts by answering questions based on the contract text and analysis results.

RULES:
- Answer ONLY based on the provided contract text and analysis. Never make up information.
- If the answer isn't in the provided context, say "I don't see that information in the analyzed contract."
- Cite specific clauses or sections when possible.
- Be concise but thorough. Use plain language, not legal jargon.
- Always end with: "⚠️ This is AI analysis, not legal advice. Consult an attorney for legal decisions."
""")

    # Add the analysis summary if available
    if analysis_result:
        risk = analysis_result.get("risk", {})
        parts.append(f"""
═══ CONTRACT ANALYSIS SUMMARY ═══
Risk Score: {risk.get('score', 'N/A')}/100 (Grade {risk.get('grade', 'N/A')})
Risk Breakdown: {risk.get('breakdown', {})}
Total Clauses Analyzed: {analysis_result.get('metadata', {}).get('total_clauses', 'N/A')}
Flagged Clauses: {analysis_result.get('metadata', {}).get('flagged_clauses', 'N/A')}
""")

        # Add a summary of detected clauses
        clauses = analysis_result.get("clauses", [])
        if clauses:
            clause_summary = []
            seen = set()
            for c in clauses:
                key = c["label"]
                if key not in seen:
                    seen.add(key)
                    risk_level = c.get("risk", "LOW")
                    clause_summary.append(f"  • [{risk_level}] {key}: {c.get('description', '')}")
            parts.append("═══ DETECTED CLAUSES ═══\n" + "\n".join(clause_summary[:20]))

        # Add a summary of extracted entities
        entities = analysis_result.get("entities", [])
        if entities:
            entity_summary = []
            seen = set()
            for e in entities:
                key = f"{e['type']}: {e['text']}"
                if key not in seen and len(seen) < 15:
                    seen.add(key)
                    entity_summary.append(f"  • {e['type']}: {e['text']}")
            parts.append("═══ EXTRACTED ENTITIES ═══\n" + "\n".join(entity_summary))

        # Add contradictions
        contradictions = analysis_result.get("contradictions", [])
        if contradictions:
            contra_summary = []
            for c in contradictions:
                contra_summary.append(f"  • [{c['type']}] {c['explanation']}")
            parts.append("═══ CONTRADICTIONS / ISSUES ═══\n" + "\n".join(contra_summary))

    # Add the retrieved contract text
    if retrieved_chunks:
        context_text = "\n---\n".join(c["text"] for c in retrieved_chunks)
        parts.append(f"""
═══ RELEVANT CONTRACT TEXT (Retrieved) ═══
{context_text}
""")

    return "\n\n".join(parts)
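The layering idea — structured analysis summary first, verbatim retrieved text last — can be sketched in isolation. The stripped-down builder below uses the same field names (`risk`, `score`, `grade`, `text`) but is an illustrative sketch, not the module's function:

```python
def build_prompt(analysis, retrieved):
    """Layer structured analysis above verbatim retrieved contract text."""
    parts = ["You are a contract Q&A assistant. Answer only from the context below."]
    risk = analysis.get("risk", {})
    parts.append(f"Risk Score: {risk.get('score', 'N/A')}/100 (Grade {risk.get('grade', 'N/A')})")
    if retrieved:
        # retrieved chunks go last, separated so the model can quote them
        parts.append("RELEVANT CONTRACT TEXT:\n" + "\n---\n".join(c["text"] for c in retrieved))
    return "\n\n".join(parts)

prompt = build_prompt(
    {"risk": {"score": 72, "grade": "C"}},
    [{"text": "Either party may terminate on 30 days' written notice."}],
)
print(prompt)
```

Keeping the analysis summary out of the RAG path means it is always present, while the retrieved text varies per question.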
# ═══════════════════════════════════════════════════════════════════════
# CHAT RESPONSE (Streaming)
# ═══════════════════════════════════════════════════════════════════════

# LLM model to use
_LLM_MODEL = "Qwen/Qwen2.5-7B-Instruct"

def chat_respond(message, history, chunks, embeddings, analysis_result):
    """
    RAG chatbot response function for gr.ChatInterface.

    Args:
        message: User's question (str)
        history: Chat history (list of dicts with role/content)
        chunks: Contract text chunks (list of str)
        embeddings: Chunk embeddings (numpy array or None)
        analysis_result: Full analysis result dict (or None)

    Yields:
        Partial response string (streaming)
    """
    # Validate inputs
    if not chunks or embeddings is None:
        yield ("⚠️ No contract loaded yet. Please upload and analyze a contract in the "
               "**📄 Single Contract Analysis** tab first, then come back here to ask questions.")
        return

    if not message or not message.strip():
        yield "Please ask a question about your contract."
        return

    # Step 1: Retrieve relevant chunks
    retrieved = retrieve_chunks(message, chunks, embeddings, top_k=5)

    # Step 2: Build the system prompt with analysis + retrieved context
    system_prompt = _build_system_prompt(analysis_result, retrieved)

    # Step 3: Build the message history for the LLM
    messages = [{"role": "system", "content": system_prompt}]

    # Add recent history (last 6 turns to stay inside the context window)
    if history:
        for h in history[-6:]:
            messages.append({"role": h["role"], "content": h["content"]})

    messages.append({"role": "user", "content": message})

    # Step 4: Stream the response from the LLM
    client = _get_llm_client()
    if client is None:
        yield ("⚠️ LLM service unavailable. Please ensure `huggingface_hub` is installed "
               "and `HF_TOKEN` is set.")
        return

    try:
        stream = client.chat_completion(
            model=_LLM_MODEL,
            messages=messages,
            max_tokens=1024,
            stream=True,
            temperature=0.3,  # low temperature for factual responses
        )
        partial = ""
        for chunk in stream:
            token = chunk.choices[0].delta.content or ""
            partial += token
            yield partial
    except Exception as e:
        error_msg = str(e)
        if "rate limit" in error_msg.lower() or "429" in error_msg:
            yield ("⚠️ Rate limit reached on the free HF Inference API. "
                   "Please wait a moment and try again.")
        elif "401" in error_msg or "unauthorized" in error_msg.lower():
            yield "⚠️ Authentication error. Please set your HF_TOKEN in the Space settings."
        else:
            yield f"⚠️ Error generating response: {error_msg}\n\nPlease try again."
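The history-window trick in Step 3 (keep only the last 6 turns) is easy to verify on its own. A self-contained sketch of that message assembly, with a hypothetical helper name:

```python
def build_messages(system_prompt, history, message, max_turns=6):
    """Assemble the LLM message list, keeping only the most recent turns."""
    msgs = [{"role": "system", "content": system_prompt}]
    # older turns are dropped so the prompt stays inside the context window
    msgs += [{"role": h["role"], "content": h["content"]} for h in history[-max_turns:]]
    msgs.append({"role": "user", "content": message})
    return msgs

# 10-turn fake history: q0, a1, q2, a3, ...
hist = [
    {"role": "user", "content": f"q{i}"} if i % 2 == 0
    else {"role": "assistant", "content": f"a{i}"}
    for i in range(10)
]
msgs = build_messages("SYS", hist, "final question")
print(len(msgs))  # → 8: system + last 6 turns + new user message
```

Only turns q4 through a9 survive the cut; everything earlier is silently dropped rather than summarized, which keeps the implementation simple at the cost of long-range memory.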
# ═══════════════════════════════════════════════════════════════════════
# INDEXING HELPER (combines chunking + embedding)
# ═══════════════════════════════════════════════════════════════════════

def index_contract(text):
    """
    Chunk and embed contract text for RAG retrieval.

    Returns: (chunks, embeddings, status_message)
        chunks: list of str
        embeddings: numpy array or None
        status_message: str
    """
    if not text or len(text.strip()) < 50:
        return [], None, "⚠️ No contract text to index"

    chunks = chunk_contract_text(text)
    if not chunks:
        return [], None, "⚠️ Could not split contract into chunks"

    embeddings = build_embeddings(chunks)
    if embeddings is None:
        return chunks, None, "⚠️ Embedding model unavailable — chatbot will not work"

    return (
        chunks,
        embeddings,
        f"✅ Indexed {len(chunks)} chunks ({len(text)} chars) — Ready to chat!"
    )
compare.py
CHANGED

@@ -98,6 +98,28 @@ def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
     if clauses_b is None:
         clauses_b = _split_clauses(text_b)

+    # Fix 9: Detect contract types and flag cross-domain comparisons
+    _CONTRACT_TYPE_KEYWORDS = {
+        "employment": ["employee", "employer", "salary", "compensation", "benefits", "vacation", "severance", "at-will"],
+        "lease": ["landlord", "tenant", "rent", "premises", "lease", "occupancy", "security deposit", "eviction"],
+        "service": ["service provider", "customer", "SLA", "deliverables", "statement of work", "SOW"],
+        "nda": ["confidential", "non-disclosure", "disclosing party", "receiving party"],
+        "saas": ["subscription", "SaaS", "cloud", "uptime", "API", "data processing"],
+        "purchase": ["buyer", "seller", "purchase order", "goods", "shipment", "delivery"],
+    }
+
+    def _detect_contract_type(text):
+        text_lower = text.lower()
+        scores = {}
+        for ctype, keywords in _CONTRACT_TYPE_KEYWORDS.items():
+            scores[ctype] = sum(1 for kw in keywords if kw.lower() in text_lower)
+        best = max(scores, key=scores.get)
+        return best if scores[best] >= 2 else "general"
+
+    type_a = _detect_contract_type(text_a)
+    type_b = _detect_contract_type(text_b)
+    is_cross_domain = type_a != type_b and type_a != "general" and type_b != "general"
+
     # Build clause type maps
     type_map_a = defaultdict(list)
     type_map_b = defaultdict(list)

@@ -111,8 +133,9 @@ def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
     matched_b = set()
     modified = []

+    # Fix 10: Raise thresholds to reject false "modified" matches
+    SIMILARITY_THRESHOLD = 0.75  # was 0.70 — too many false matches
+    MODIFIED_THRESHOLD = 0.55    # was 0.40 — "Good Reason" ≠ "Force Majeure"

     for i, ca in enumerate(clauses_a):
         best_sim = 0

@@ -181,12 +204,20 @@ def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
         risk_delta = "Similar risk profiles"
         risk_winner = "tie"

+    # Fix 9: Cross-domain warning
+    if is_cross_domain:
+        risk_delta = f"Cross-domain comparison ({type_a} vs {type_b}) — risk delta not meaningful across different contract types"
+        risk_winner = "cross-domain"
+
     comparison_method = "semantic (sentence embeddings)" if _embedder is not None else "lexical (string matching)"

     return {
         "alignment_score": round(alignment, 3),
         "contract_a_clauses": len(clauses_a),
         "contract_b_clauses": len(clauses_b),
+        "contract_a_type": type_a,
+        "contract_b_type": type_b,
+        "is_cross_domain": is_cross_domain,
         "added_clauses": [{"text": c[:200], "type": _extract_clause_type(c)} for c in added[:50]],
         "removed_clauses": [{"text": c[:200], "type": _extract_clause_type(c)} for c in removed[:50]],
         "modified_clauses": modified[:50],
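The keyword-vote type detection added in Fix 9 can be tried standalone. The sketch below mirrors `_detect_contract_type` with an abbreviated keyword table (the names and threshold of 2 follow the diff; the table is truncated for brevity):

```python
KEYWORDS = {
    "employment": ["employee", "employer", "salary", "severance"],
    "lease": ["landlord", "tenant", "rent", "premises"],
    "nda": ["confidential", "non-disclosure", "receiving party"],
}

def detect_type(text):
    """Vote by keyword presence; require at least 2 hits to commit to a type."""
    low = text.lower()
    scores = {t: sum(kw in low for kw in kws) for t, kws in KEYWORDS.items()}
    best = max(scores, key=scores.get)
    return best if scores[best] >= 2 else "general"

print(detect_type("The Tenant shall pay Rent to the Landlord for the Premises."))  # → lease
print(detect_type("This Agreement is made by the parties."))                       # → general
```

The ≥2 floor prevents a single incidental word (e.g. "delivery" in an NDA) from mislabeling a contract, which is what makes the cross-domain flag trustworthy.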
ml/ClauseGuard_DeBERTa_Training.ipynb
ADDED
|
@@ -0,0 +1,1041 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# 🛡️ ClauseGuard v4 — DeBERTa-v3-large 2-Stage Training\n",
        "\n",
        "**Goal:** Train a production-grade contract clause classifier that replaces the current Legal-BERT-base (50% F1 → target 80-87% F1)\n",
        "\n",
        "## Architecture\n",
        "| Setting | Value | Source |\n",
        "|---------|-------|--------|\n",
        "| Base model | `microsoft/deberta-v3-large` (435M params) | LexGLUE: outperforms Legal-BERT by 7-10pp |\n",
        "| Max length | 512 tokens | MAUD paper: covers 72.4% of clauses without truncation |\n",
        "| Loss function | Asymmetric Loss (γ-=4, clip=0.05) | ASL paper (2009.14119): +3-8pp on rare classes |\n",
        "| Training | Full fine-tuning (no LoRA) | Full FT wins for encoder classification |\n",
        "\n",
        "## 2-Stage Training Pipeline\n",
        "1. **Stage 1 — LEDGAR** (60K legal provisions, 100 classes): Teaches \"what types of contract clauses exist\"\n",
        "2. **Stage 2 — CUAD** (41 CUAD classes): Target task with Asymmetric Loss for class imbalance\n",
        "\n",
        "**Runtime:** ~8-12 hours on T4 GPU (or ~4-6 hours on A100)\n",
        "\n",
        "**Before running:**\n",
        "1. `Runtime` → `Change runtime type` → **T4 GPU**\n",
        "2. `Runtime` → `Run all`\n",
        "3. Paste your HuggingFace token when prompted"
      ],
      "metadata": {}
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 1: Install Dependencies"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q transformers datasets scikit-learn accelerate huggingface_hub torch\n",
        "!pip install -q trackio  # optional: experiment tracking"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 2: Login to HuggingFace Hub"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from huggingface_hub import login\n",
        "login()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 3: Configuration"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "# CONFIGURATION — Edit these values\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "\n",
        "BASE_MODEL = \"microsoft/deberta-v3-large\"  # 435M params, MIT license\n",
        "MAX_LENGTH = 512  # covers 72.4% of clauses\n",
        "HUB_MODEL_ID = \"gaurv007/clauseguard-deberta-v3-large\"  # ← your model repo\n",
        "\n",
        "# Stage 1: LEDGAR config\n",
        "STAGE1_EPOCHS = 5  # LEDGAR is large, converges fast\n",
        "STAGE1_LR = 2e-5\n",
        "STAGE1_BATCH = 2  # T4 fp32: reduced for DeBERTa-v3 compatibility\n",
        "STAGE1_GRAD_ACCUM = 16  # effective batch = 32 (2 * 16)\n",
        "\n",
        "# Stage 2: CUAD config\n",
        "STAGE2_EPOCHS = 20\n",
        "STAGE2_LR = 1e-5  # lower LR for fine-tuning pretrained model\n",
        "STAGE2_BATCH = 2  # T4 fp32: reduced for DeBERTa-v3 compatibility\n",
        "STAGE2_GRAD_ACCUM = 16  # effective batch = 32 (2 * 16)\n",
        "EARLY_STOPPING_PATIENCE = 3\n",
        "\n",
        "# ASL hyperparameters (from arxiv 2009.14119)\n",
        "ASL_GAMMA_POS = 0\n",
        "ASL_GAMMA_NEG = 4\n",
        "ASL_CLIP = 0.05\n",
        "\n",
        "# Weight decay (DeBERTa default)\n",
        "WEIGHT_DECAY = 0.06\n",
        "WARMUP_RATIO = 0.1\n",
        "\n",
        "SEED = 42\n",
        "\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "\n",
        "# CUAD 41 label names (must match class_id 0-40 in CUAD dataset)\n",
        "CUAD_LABELS = [\n",
        "    \"Document Name\",  # 0\n",
        "    \"Parties\",  # 1\n",
        "    \"Agreement Date\",  # 2\n",
        "    \"Effective Date\",  # 3\n",
        "    \"Expiration Date\",  # 4\n",
        "    \"Renewal Term\",  # 5\n",
        "    \"Notice Period to Terminate Renewal\",  # 6\n",
        "    \"Governing Law\",  # 7\n",
        "    \"Most Favored Nation\",  # 8\n",
        "    \"Non-Compete\",  # 9\n",
        "    \"Exclusivity\",  # 10\n",
        "    \"No-Solicit of Customers\",  # 11\n",
        "    \"No-Solicit of Employees\",  # 12\n",
        "    \"Non-Disparagement\",  # 13\n",
        "    \"Termination for Convenience\",  # 14\n",
        "    \"ROFR/ROFO/ROFN\",  # 15\n",
        "    \"Change of Control\",  # 16\n",
        "    \"Anti-Assignment\",  # 17\n",
        "    \"Revenue/Profit Sharing\",  # 18\n",
        "    \"Price Restriction\",  # 19\n",
        "    \"Minimum Commitment\",  # 20\n",
        "    \"Volume Restriction\",  # 21\n",
        "    \"IP Ownership Assignment\",  # 22\n",
        "    \"Joint IP Ownership\",  # 23\n",
        "    \"License Grant\",  # 24\n",
        "    \"Non-Transferable License\",  # 25\n",
        "    \"Affiliate License-Licensor\",  # 26\n",
        "    \"Affiliate License-Licensee\",  # 27\n",
        "    \"Unlimited/All-You-Can-Eat License\",  # 28\n",
        "    \"Irrevocable or Perpetual License\",  # 29\n",
        "    \"Source Code Escrow\",  # 30\n",
        "    \"Post-Termination Services\",  # 31\n",
        "    \"Audit Rights\",  # 32\n",
        "    \"Uncapped Liability\",  # 33\n",
        "    \"Cap on Liability\",  # 34\n",
        "    \"Liquidated Damages\",  # 35\n",
        "    \"Warranty Duration\",  # 36\n",
        "    \"Insurance\",  # 37\n",
        "    \"Covenant Not to Sue\",  # 38\n",
        "    \"Third Party Beneficiary\",  # 39\n",
        "    \"Other\",  # 40\n",
        "]\n",
        "\n",
        "NUM_CUAD_LABELS = len(CUAD_LABELS)  # 41\n",
        "\n",
        "print(\"🛡️ ClauseGuard v4 Training Configuration\")\n",
        "print(f\"  Base model: {BASE_MODEL}\")\n",
        "print(f\"  Max length: {MAX_LENGTH}\")\n",
        "print(f\"  Hub model: {HUB_MODEL_ID}\")\n",
        "print(f\"  GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
        "print(f\"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\" if torch.cuda.is_available() else \"\")\n",
        "print(f\"  CUAD classes: {NUM_CUAD_LABELS}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 4: Load Datasets"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset, Dataset\n",
        "import pandas as pd\n",
        "from collections import Counter\n",
        "\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "# Stage 1: LEDGAR (100 classes, single-label)\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "print(\"📚 Loading LEDGAR dataset...\")\n",
        "ledgar = load_dataset(\"coastalcph/lex_glue\", \"ledgar\")\n",
        "print(f\"  Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,} | Test: {len(ledgar['test']):,}\")\n",
        "num_ledgar_labels = ledgar['train'].features['label'].num_classes\n",
        "print(f\"  Classes: {num_ledgar_labels}\")\n",
        "\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "# Stage 2: CUAD (41 classes — reformulated for classification)\n",
        "# ═══════════════════════════════════════════════════════════════\n",
        "print(\"\\n📚 Loading CUAD classification dataset...\")\n",
        "cuad_raw = load_dataset(\"dvgodoy/CUAD_v1_Contract_Understanding_clause_classification\", split=\"train\")\n",
        "print(f\"  Total rows: {len(cuad_raw):,}\")\n",
        "\n",
        "# Analyze class distribution\n",
        "class_counts = Counter(cuad_raw['class_id'])\n",
        "print(f\"  Unique classes: {len(class_counts)}\")\n",
        "print(\"\\n  Class distribution:\")\n",
        "for cid in sorted(class_counts.keys()):\n",
        "    label_name = CUAD_LABELS[cid] if cid < len(CUAD_LABELS) else f\"Unknown-{cid}\"\n",
        "    count = class_counts[cid]\n",
        "    bar = '█' * min(50, count // 10)\n",
        "    print(f\"  {cid:2d} {label_name:40s} {count:5d} {bar}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 5: Prepare CUAD Train/Val/Test Splits"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# CUAD only has a train split — create val/test by splitting by file_name\n",
        "# (so no data leakage between contracts)\n",
        "cuad_df = cuad_raw.to_pandas()\n",
        "\n",
        "# Get unique file names\n",
        "unique_files = cuad_df['file_name'].unique()\n",
        "print(f\"Unique contracts: {len(unique_files)}\")\n",
        "\n",
        "# Split files 80/10/10\n",
        "train_files, test_files = train_test_split(unique_files, test_size=0.2, random_state=SEED)\n",
        "val_files, test_files = train_test_split(test_files, test_size=0.5, random_state=SEED)\n",
        "\n",
        "cuad_train_df = cuad_df[cuad_df['file_name'].isin(train_files)]\n",
        "cuad_val_df = cuad_df[cuad_df['file_name'].isin(val_files)]\n",
        "cuad_test_df = cuad_df[cuad_df['file_name'].isin(test_files)]\n",
        "\n",
        "print(f\"CUAD splits — Train: {len(cuad_train_df)} | Val: {len(cuad_val_df)} | Test: {len(cuad_test_df)}\")\n",
        "print(f\"Train contracts: {len(train_files)} | Val contracts: {len(val_files)} | Test contracts: {len(test_files)}\")\n",
        "\n",
        "# Convert to HF Dataset\n",
        "cuad_train = Dataset.from_pandas(cuad_train_df.reset_index(drop=True))\n",
        "cuad_val = Dataset.from_pandas(cuad_val_df.reset_index(drop=True))\n",
        "cuad_test = Dataset.from_pandas(cuad_test_df.reset_index(drop=True))\n",
        "\n",
        "# Verify class distribution in each split\n",
        "for name, ds in [(\"Train\", cuad_train), (\"Val\", cuad_val), (\"Test\", cuad_test)]:\n",
        "    counts = Counter(ds['class_id'])\n",
        "    empty_classes = [i for i in range(NUM_CUAD_LABELS) if counts.get(i, 0) == 0]\n",
        "    print(f\"  {name}: {len(ds)} rows, {len(counts)} classes present, {len(empty_classes)} classes missing: {empty_classes[:5]}...\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 6: Tokenizer & Preprocessing"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "print(f\"Loading tokenizer: {BASE_MODEL}\")\n",
        "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
        "\n",
        "# ── LEDGAR preprocessing (single-label) ──\n",
        "def preprocess_ledgar(examples):\n",
        "    tokenized = tokenizer(\n",
        "        examples[\"text\"],\n",
        "        truncation=True,\n",
        "        max_length=MAX_LENGTH,\n",
        "        padding=False,\n",
        "    )\n",
        "    tokenized[\"labels\"] = examples[\"label\"]  # int label for CrossEntropy\n",
        "    return tokenized\n",
        "\n",
        "# ── CUAD preprocessing (single-label per clause, 41 classes) ──\n",
        "def preprocess_cuad(examples):\n",
        "    tokenized = tokenizer(\n",
        "        examples[\"clause\"],\n",
        "        truncation=True,\n",
        "        max_length=MAX_LENGTH,\n",
        "        padding=False,\n",
        "    )\n",
        "    tokenized[\"labels\"] = examples[\"class_id\"]  # int label for CrossEntropy + ASL\n",
        "    return tokenized\n",
        "\n",
        "print(\"Tokenizing LEDGAR...\")\n",
        "ledgar_tokenized = ledgar.map(\n",
        "    preprocess_ledgar, batched=True,\n",
        "    remove_columns=ledgar[\"train\"].column_names,\n",
        "    desc=\"Tokenizing LEDGAR\"\n",
        ")\n",
        "\n",
        "print(\"Tokenizing CUAD...\")\n",
        "cuad_train_tok = cuad_train.map(\n",
        "    preprocess_cuad, batched=True,\n",
        "    remove_columns=cuad_train.column_names,\n",
        "    desc=\"Tokenizing CUAD train\"\n",
        ")\n",
        "cuad_val_tok = cuad_val.map(\n",
        "    preprocess_cuad, batched=True,\n",
        "    remove_columns=cuad_val.column_names,\n",
        "    desc=\"Tokenizing CUAD val\"\n",
        ")\n",
        "cuad_test_tok = cuad_test.map(\n",
        "    preprocess_cuad, batched=True,\n",
        "    remove_columns=cuad_test.column_names,\n",
        "    desc=\"Tokenizing CUAD test\"\n",
        ")\n",
        "\n",
        "# Check token lengths\n",
        "train_lengths = [len(x) for x in cuad_train_tok['input_ids']]\n",
        "print(\"\\n📊 CUAD token length stats:\")\n",
        "print(f\"  Mean: {np.mean(train_lengths):.0f} | Median: {np.median(train_lengths):.0f}\")\n",
        "print(f\"  95th pct: {np.percentile(train_lengths, 95):.0f} | Max: {max(train_lengths)}\")\n",
        "print(f\"  Truncated (>512): {sum(1 for l in train_lengths if l >= MAX_LENGTH)} ({sum(1 for l in train_lengths if l >= MAX_LENGTH)/len(train_lengths)*100:.1f}%)\")\n",
        "print(\"✅ Tokenization complete!\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 7: Asymmetric Loss Function\n",
        "\n",
        "From [Asymmetric Loss For Multi-Label Classification](https://arxiv.org/abs/2009.14119) (ICCV 2021).\n",
        "\n",
        "Key idea: Down-weight easy negatives more aggressively than positives. Critical for CUAD, where most labels are negative for any given clause."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "class AsymmetricLoss(nn.Module):\n",
        "    \"\"\"\n",
        "    Asymmetric Loss from arxiv:2009.14119.\n",
        "\n",
        "    For multi-class (single-label) classification with class imbalance,\n",
        "    we use the multi-class variant — apply focal-style re-weighting\n",
        "    to cross-entropy, with different gamma for correct vs incorrect classes.\n",
        "\n",
        "    For multi-label (multi-hot) classification:\n",
        "        L+ = (1-p)^γ+ * log(p)\n",
        "        L- = (pm)^γ- * log(1-pm),  pm = max(p - m, 0)\n",
        "    \"\"\"\n",
        "    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,\n",
        "                 num_classes=None, class_weights=None, mode=\"multi_class\"):\n",
        "        super().__init__()\n",
        "        self.gamma_pos = gamma_pos\n",
        "        self.gamma_neg = gamma_neg\n",
        "        self.clip = clip\n",
        "        self.eps = eps\n",
        "        self.mode = mode\n",
        "\n",
        "        # Optional class weights for severe imbalance\n",
        "        if class_weights is not None:\n",
        "            self.register_buffer('class_weights', torch.tensor(class_weights, dtype=torch.float32))\n",
        "        else:\n",
        "            self.class_weights = None\n",
        "\n",
        "    def forward(self, logits, targets):\n",
        "        if self.mode == \"multi_label\":\n",
        "            return self._multi_label_loss(logits, targets)\n",
        "        else:\n",
        "            return self._multi_class_loss(logits, targets)\n",
        "\n",
        "    def _multi_class_loss(self, logits, targets):\n",
        "        \"\"\"Focal-style cross-entropy with asymmetric gamma for single-label classification.\"\"\"\n",
        "        # Standard cross-entropy with class weights\n",
        "        if self.class_weights is not None:\n",
        "            ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')\n",
        "        else:\n",
        "            ce_loss = F.cross_entropy(logits, targets, reduction='none')\n",
        "\n",
        "        # Apply focal modulation\n",
        "        probs = F.softmax(logits, dim=-1)\n",
        "        # Get probability of the correct class\n",
        "        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)\n",
        "\n",
        "        # Focal weight: (1 - p_t)^gamma_neg — up-weights hard examples (low p_t),\n",
        "        # down-weights easy ones (gamma_pos is unused in this multi-class variant)\n",
        "        focal_weight = (1 - p_t) ** self.gamma_neg\n",
        "\n",
        "        loss = focal_weight * ce_loss\n",
        "        return loss.mean()\n",
        "\n",
        "    def _multi_label_loss(self, logits, targets):\n",
        "        \"\"\"Full ASL for multi-label classification.\"\"\"\n",
        "        p = torch.sigmoid(logits)\n",
        "\n",
        "        if self.clip is not None and self.clip > 0:\n",
        "            p_m = torch.clamp(p - self.clip, min=0)\n",
        "        else:\n",
        "            p_m = p\n",
        "\n",
        "        loss_pos = targets * (1 - p) ** self.gamma_pos * torch.log(p + self.eps)\n",
        "        loss_neg = (1 - targets) * p_m ** self.gamma_neg * torch.log(1 - p_m + self.eps)\n",
        "\n",
        "        loss = -(loss_pos + loss_neg)\n",
        "        return loss.mean()\n",
        "\n",
        "\n",
        "print(\"✅ AsymmetricLoss defined\")\n",
        "print(f\"  γ+ = {ASL_GAMMA_POS}, γ- = {ASL_GAMMA_NEG}, clip = {ASL_CLIP}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 8: Custom Trainer with ASL"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import Trainer\n",
        "\n",
        "\n",
        "class ASLTrainer(Trainer):\n",
        "    \"\"\"Custom Trainer that uses Asymmetric Loss instead of standard CrossEntropy.\"\"\"\n",
        "\n",
        "    def __init__(self, *args, asl_loss_fn=None, **kwargs):\n",
        "        super().__init__(*args, **kwargs)\n",
        "        self.asl = asl_loss_fn\n",
        "\n",
        "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
        "        labels = inputs.pop(\"labels\")\n",
        "        outputs = model(**inputs)\n",
        "        logits = outputs.logits\n",
        "\n",
        "        if self.asl is not None:\n",
        "            loss = self.asl(logits, labels)\n",
        "        else:\n",
        "            # Fallback to standard cross-entropy\n",
        "            loss = F.cross_entropy(logits, labels)\n",
        "\n",
        "        return (loss, outputs) if return_outputs else loss\n",
        "\n",
        "\n",
        "print(\"✅ ASLTrainer defined\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Step 9: Metrics"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.metrics import f1_score, precision_score, recall_score, classification_report\n",
        "\n",
        "\n",
        "def compute_metrics_single_label(eval_pred):\n",
        "    \"\"\"Metrics for single-label classification (LEDGAR & CUAD).\"\"\"\n",
        "    logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
        "    preds = np.argmax(logits, axis=-1)\n",
        "\n",
        "    micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
        "    macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
        "    weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n",
        "    accuracy = (preds == labels).mean()\n",
        "\n",
        "    return {\n",
        "        \"accuracy\": accuracy,\n",
        "        \"micro_f1\": micro_f1,\n",
        "        \"macro_f1\": macro_f1,\n",
        "        \"weighted_f1\": weighted_f1,\n",
        "    }\n",
        "\n",
        "\n",
        "def compute_metrics_cuad_detailed(eval_pred):\n",
        "    \"\"\"Detailed metrics for CUAD — includes per-class F1.\"\"\"\n",
        "    logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
        "    preds = np.argmax(logits, axis=-1)\n",
        "\n",
        "    micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
        "    macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
        "    weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n",
        "    accuracy = (preds == labels).mean()\n",
        "\n",
        "    # Per-class F1\n",
        "    per_class_f1 = f1_score(labels, preds, average=None, zero_division=0)\n",
        "    class_metrics = {}\n",
        "    for i, f1_val in enumerate(per_class_f1):\n",
        "        if i < len(CUAD_LABELS):\n",
        "            # Truncate label name for cleaner logging\n",
        "            safe_name = CUAD_LABELS[i][:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n",
        "            class_metrics[f\"f1_{safe_name}\"] = float(f1_val)\n",
        "\n",
        "    return {\n",
        "        \"accuracy\": accuracy,\n",
        "        \"micro_f1\": micro_f1,\n",
        "        \"macro_f1\": macro_f1,\n",
        "        \"weighted_f1\": weighted_f1,\n",
        "        **class_metrics,\n",
        "    }\n",
        "\n",
        "\n",
        "print(\"✅ Metrics functions defined\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "# 🏋️ STAGE 1: Pre-fine-tune on LEDGAR\n",
        "\n",
        "**Goal:** Teach DeBERTa-v3-large what types of contract clauses exist (100 classes, ~60K examples).\n",
        "\n",
        "This stage uses standard cross-entropy loss since LEDGAR is well-balanced.\n",
        "\n",
        "**Expected:** ~85-90% micro-F1 after 3-5 epochs (~3-5 hours on T4, ~1-2 hours on A100)"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import (\n",
        "    AutoConfig,\n",
        "    AutoModelForSequenceClassification,\n",
        "    TrainingArguments,\n",
        "    DataCollatorWithPadding,\n",
        "    EarlyStoppingCallback,\n",
        ")\n",
        "\n",
        "print(f\"🏋️ STAGE 1: Pre-fine-tune on LEDGAR ({num_ledgar_labels} classes)\")\n",
        "print(f\"  Loading {BASE_MODEL}...\")\n",
        "\n",
        "# Load model for Stage 1 (100 classes, single-label)\n",
        "stage1_model = AutoModelForSequenceClassification.from_pretrained(\n",
        "    BASE_MODEL,\n",
        "    num_labels=num_ledgar_labels,\n",
        "    problem_type=\"single_label_classification\",\n",
        "    ignore_mismatched_sizes=True,\n",
        ")\n",
        "\n",
        "total_params = sum(p.numel() for p in stage1_model.parameters())\n",
        "trainable_params = sum(p.numel() for p in stage1_model.parameters() if p.requires_grad)\n",
        "print(f\"  Total parameters: {total_params:,}\")\n",
        "print(f\"  Trainable parameters: {trainable_params:,}\")\n",
        "\n",
        "stage1_args = TrainingArguments(\n",
        "    output_dir=\"./stage1_ledgar\",\n",
        "    num_train_epochs=STAGE1_EPOCHS,\n",
        "    per_device_train_batch_size=STAGE1_BATCH,\n",
        "    per_device_eval_batch_size=4,\n",
        "    gradient_accumulation_steps=STAGE1_GRAD_ACCUM,\n",
        "    learning_rate=STAGE1_LR,\n",
        "    weight_decay=WEIGHT_DECAY,\n",
        "    warmup_ratio=WARMUP_RATIO,\n",
        "    lr_scheduler_type=\"cosine\",\n",
        "    eval_strategy=\"epoch\",\n",
        "    save_strategy=\"epoch\",\n",
        "    save_total_limit=2,\n",
        "    load_best_model_at_end=True,\n",
        "    metric_for_best_model=\"macro_f1\",\n",
        "    greater_is_better=True,\n",
        "    bf16=False,  # DeBERTa-v3 breaks with fp16 gradient scaler; fp32 is safest on T4\n",
        "    fp16=False,\n",
        "    logging_strategy=\"steps\",\n",
        "    logging_steps=50,\n",
        "    logging_first_step=True,\n",
        "    disable_tqdm=False,\n",
        "    report_to=\"none\",\n",
        "    dataloader_num_workers=2,\n",
        "    seed=SEED,\n",
        "    gradient_checkpointing=True,  # Critical for T4 (16GB VRAM)\n",
        ")\n",
        "\n",
        "stage1_trainer = Trainer(\n",
        "    model=stage1_model,\n",
        "    args=stage1_args,\n",
        "    train_dataset=ledgar_tokenized[\"train\"],\n",
        "    eval_dataset=ledgar_tokenized[\"validation\"],\n",
        "    processing_class=tokenizer,\n",
        "    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
        "    compute_metrics=compute_metrics_single_label,\n",
        "    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n",
        ")\n",
        "\n",
        "print(\"\\n🚀 Starting Stage 1 training...\")\n",
        "stage1_result = stage1_trainer.train()\n",
        "print(f\"\\n✅ Stage 1 complete! Loss: {stage1_result.training_loss:.4f}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate Stage 1 on LEDGAR test set\n",
        "print(\"📊 Stage 1 — LEDGAR Test Evaluation\")\n",
        "stage1_test = stage1_trainer.evaluate(ledgar_tokenized[\"test\"])\n",
        "print(f\"  Accuracy: {stage1_test['eval_accuracy']:.4f}\")\n",
        "print(f\"  Micro-F1: {stage1_test['eval_micro_f1']:.4f}\")\n",
        "print(f\"  Macro-F1: {stage1_test['eval_macro_f1']:.4f}\")\n",
        "print(f\"  Weighted-F1: {stage1_test['eval_weighted_f1']:.4f}\")\n",
        "\n",
        "# Save Stage 1 checkpoint\n",
        "STAGE1_CHECKPOINT = \"./stage1_ledgar_best\"\n",
        "stage1_trainer.save_model(STAGE1_CHECKPOINT)\n",
        "tokenizer.save_pretrained(STAGE1_CHECKPOINT)\n",
        "print(f\"\\n💾 Stage 1 checkpoint saved to {STAGE1_CHECKPOINT}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "---\n",
        "# 🏋️ STAGE 2: Fine-tune on CUAD 41-class with Asymmetric Loss\n",
        "\n",
        "**Goal:** Learn the 41 CUAD contract clause types from the Stage 1 backbone.\n",
        "\n",
        "Key improvements over current ClauseGuard:\n",
        "- DeBERTa-v3-large backbone pre-trained on LEDGAR (Stage 1)\n",
        "- 512 tokens (vs 256) — captures full clause content\n",
        "- Asymmetric Loss for class imbalance\n",
        "- Full fine-tuning (no LoRA bottleneck)\n",
        "\n",
        "**Expected:** 75-87% macro-F1 after 10-20 epochs (~5-8 hours on T4, ~2-4 hours on A100)"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Free Stage 1 model memory before loading Stage 2\n",
        "del stage1_model, stage1_trainer\n",
        "torch.cuda.empty_cache()\n",
        "import gc; gc.collect()\n",
        "\n",
        "print(f\"🏋️ STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL\")\n",
        "\n",
        "# Load Stage 1 checkpoint with new head (100 → 41 classes)\n",
        "stage2_model = AutoModelForSequenceClassification.from_pretrained(\n",
        "    STAGE1_CHECKPOINT,\n",
        "    num_labels=NUM_CUAD_LABELS,\n",
        "    ignore_mismatched_sizes=True,  # classifier head: 100 → 41\n",
        "    problem_type=\"single_label_classification\",\n",
        ")\n",
        "\n",
        "print(f\"  Loaded Stage 1 backbone with new {NUM_CUAD_LABELS}-class head\")\n",
        "print(f\"  Parameters: {sum(p.numel() for p in stage2_model.parameters()):,}\")\n",
        "\n",
        "# Compute class weights from training distribution\n",
        "train_class_counts = Counter(cuad_train_tok['labels'])\n",
        "total_samples = sum(train_class_counts.values())\n",
        "class_weights = []\n",
        "for i in range(NUM_CUAD_LABELS):\n",
        "    count = train_class_counts.get(i, 1)  # avoid div by zero\n",
        "    # Inverse frequency weighting, capped\n",
        "    weight = min(10.0, total_samples / (NUM_CUAD_LABELS * count))\n",
        "    class_weights.append(weight)\n",
        "\n",
        "print(f\"  Class weight range: [{min(class_weights):.2f}, {max(class_weights):.2f}]\")\n",
        "\n",
        "# Create ASL loss\n",
        "asl_loss = AsymmetricLoss(\n",
        "    gamma_pos=ASL_GAMMA_POS,\n",
        "    gamma_neg=ASL_GAMMA_NEG,\n",
        "    clip=ASL_CLIP,\n",
        "    num_classes=NUM_CUAD_LABELS,\n",
        "    class_weights=class_weights,\n",
        "    mode=\"multi_class\",  # single-label per clause\n",
        ")\n",
        "# Move to GPU\n",
        "if torch.cuda.is_available():\n",
        "    asl_loss = asl_loss.cuda()\n",
        "\n",
        "stage2_args = TrainingArguments(\n",
        "    output_dir=\"./stage2_cuad\",\n",
        "    num_train_epochs=STAGE2_EPOCHS,\n",
        "    per_device_train_batch_size=STAGE2_BATCH,\n",
        "    per_device_eval_batch_size=4,
|
| 729 |
+
" gradient_accumulation_steps=STAGE2_GRAD_ACCUM,\n",
|
| 730 |
+
" learning_rate=STAGE2_LR,\n",
|
| 731 |
+
" weight_decay=WEIGHT_DECAY,\n",
|
| 732 |
+
" warmup_ratio=WARMUP_RATIO,\n",
|
| 733 |
+
" lr_scheduler_type=\"cosine\",\n",
|
| 734 |
+
" eval_strategy=\"epoch\",\n",
|
| 735 |
+
" save_strategy=\"epoch\",\n",
|
| 736 |
+
" save_total_limit=3,\n",
|
| 737 |
+
" load_best_model_at_end=True,\n",
|
| 738 |
+
" metric_for_best_model=\"macro_f1\",\n",
|
| 739 |
+
" greater_is_better=True,\n",
|
| 740 |
+
" bf16=False, # DeBERTa-v3 breaks with fp16 gradient scaler; fp32 is safest on T4\n",
|
| 741 |
+
" fp16=False,\n",
|
| 742 |
+
" logging_strategy=\"steps\",\n",
|
| 743 |
+
" logging_steps=25,\n",
|
| 744 |
+
" logging_first_step=True,\n",
|
| 745 |
+
" disable_tqdm=False,\n",
|
| 746 |
+
" report_to=\"none\",\n",
|
| 747 |
+
" push_to_hub=True,\n",
|
| 748 |
+
" hub_model_id=HUB_MODEL_ID,\n",
|
| 749 |
+
" dataloader_num_workers=2,\n",
|
| 750 |
+
" seed=SEED,\n",
|
| 751 |
+
" gradient_checkpointing=True,\n",
|
| 752 |
+
")\n",
|
| 753 |
+
"\n",
|
| 754 |
+
"stage2_trainer = ASLTrainer(\n",
|
| 755 |
+
" model=stage2_model,\n",
|
| 756 |
+
" args=stage2_args,\n",
|
| 757 |
+
" asl_loss_fn=asl_loss,\n",
|
| 758 |
+
" train_dataset=cuad_train_tok,\n",
|
| 759 |
+
" eval_dataset=cuad_val_tok,\n",
|
| 760 |
+
" processing_class=tokenizer,\n",
|
| 761 |
+
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
|
| 762 |
+
" compute_metrics=compute_metrics_cuad_detailed,\n",
|
| 763 |
+
" callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],\n",
|
| 764 |
+
")\n",
|
| 765 |
+
"\n",
|
| 766 |
+
"print(\"\\n🚀 Starting Stage 2 training with Asymmetric Loss...\")\n",
|
| 767 |
+
"stage2_result = stage2_trainer.train()\n",
|
| 768 |
+
"print(f\"\\n✅ Stage 2 complete! Loss: {stage2_result.training_loss:.4f}\")"
|
| 769 |
+
],
|
| 770 |
+
"metadata": {},
|
| 771 |
+
"execution_count": null,
|
| 772 |
+
"outputs": []
|
| 773 |
+
},
|
| 774 |
+
{
|
| 775 |
+
"cell_type": "markdown",
|
| 776 |
+
"source": [
|
| 777 |
+
"## Step 10: Evaluate Stage 2 on CUAD Test Set"
|
| 778 |
+
],
|
| 779 |
+
"metadata": {}
|
| 780 |
+
},
|
| 781 |
+
{
|
| 782 |
+
"cell_type": "code",
|
| 783 |
+
"source": [
|
| 784 |
+
"print(\"📊 Stage 2 — CUAD Test Evaluation\")\n",
|
| 785 |
+
"test_results = stage2_trainer.evaluate(cuad_test_tok)\n",
|
| 786 |
+
"\n",
|
| 787 |
+
"print(f\"\\n{'='*60}\")\n",
|
| 788 |
+
"print(f\" CUAD TEST RESULTS (DeBERTa-v3-large + LEDGAR + ASL)\")\n",
|
| 789 |
+
"print(f\"{'='*60}\")\n",
|
| 790 |
+
"print(f\" Accuracy: {test_results['eval_accuracy']:.4f}\")\n",
|
| 791 |
+
"print(f\" Micro-F1: {test_results['eval_micro_f1']:.4f}\")\n",
|
| 792 |
+
"print(f\" Macro-F1: {test_results['eval_macro_f1']:.4f}\")\n",
|
| 793 |
+
"print(f\" Weighted-F1: {test_results['eval_weighted_f1']:.4f}\")\n",
|
| 794 |
+
"print(f\"{'='*60}\")\n",
|
| 795 |
+
"\n",
|
| 796 |
+
"# Per-class F1 report\n",
|
| 797 |
+
"print(f\"\\n Per-class F1 scores:\")\n",
|
| 798 |
+
"print(f\" {'Class':<42s} {'F1':>6s}\")\n",
|
| 799 |
+
"print(f\" {'-'*48}\")\n",
|
| 800 |
+
"\n",
|
| 801 |
+
"zero_f1_classes = []\n",
|
| 802 |
+
"for i, label_name in enumerate(CUAD_LABELS):\n",
|
| 803 |
+
" safe_name = label_name[:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n",
|
| 804 |
+
" key = f\"eval_f1_{safe_name}\"\n",
|
| 805 |
+
" f1_val = test_results.get(key, 0.0)\n",
|
| 806 |
+
" bar = '█' * int(f1_val * 30)\n",
|
| 807 |
+
" status = \"\" if f1_val > 0 else \" ← ZERO\"\n",
|
| 808 |
+
" print(f\" {i:2d} {label_name:<40s} {f1_val:.4f} {bar}{status}\")\n",
|
| 809 |
+
" if f1_val == 0:\n",
|
| 810 |
+
" zero_f1_classes.append(label_name)\n",
|
| 811 |
+
"\n",
|
| 812 |
+
"print(f\"\\n Classes with zero F1: {len(zero_f1_classes)}\")\n",
|
| 813 |
+
"if zero_f1_classes:\n",
|
| 814 |
+
" for c in zero_f1_classes:\n",
|
| 815 |
+
" print(f\" ⚠️ {c}\")"
|
| 816 |
+
],
|
| 817 |
+
"metadata": {},
|
| 818 |
+
"execution_count": null,
|
| 819 |
+
"outputs": []
|
| 820 |
+
},
|
| 821 |
+
{
|
| 822 |
+
"cell_type": "markdown",
|
| 823 |
+
"source": [
|
| 824 |
+
"## Step 11: Full Classification Report"
|
| 825 |
+
],
|
| 826 |
+
"metadata": {}
|
| 827 |
+
},
|
| 828 |
+
{
|
| 829 |
+
"cell_type": "code",
|
| 830 |
+
"source": [
|
| 831 |
+
"# Generate full sklearn classification report\n",
|
| 832 |
+
"from sklearn.metrics import classification_report\n",
|
| 833 |
+
"\n",
|
| 834 |
+
"# Get predictions on test set\n",
|
| 835 |
+
"preds_output = stage2_trainer.predict(cuad_test_tok)\n",
|
| 836 |
+
"preds = np.argmax(preds_output.predictions, axis=-1)\n",
|
| 837 |
+
"labels = preds_output.label_ids\n",
|
| 838 |
+
"\n",
|
| 839 |
+
"# Only include labels that appear in test set\n",
|
| 840 |
+
"present_labels = sorted(set(labels) | set(preds))\n",
|
| 841 |
+
"target_names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f\"Class-{i}\" for i in present_labels]\n",
|
| 842 |
+
"\n",
|
| 843 |
+
"report = classification_report(\n",
|
| 844 |
+
" labels, preds,\n",
|
| 845 |
+
" labels=present_labels,\n",
|
| 846 |
+
" target_names=target_names,\n",
|
| 847 |
+
" zero_division=0,\n",
|
| 848 |
+
" digits=4,\n",
|
| 849 |
+
")\n",
|
| 850 |
+
"print(\"\\n📊 Full Classification Report:\")\n",
|
| 851 |
+
"print(report)"
|
| 852 |
+
],
|
| 853 |
+
"metadata": {},
|
| 854 |
+
"execution_count": null,
|
| 855 |
+
"outputs": []
|
| 856 |
+
},
|
| 857 |
+
{
|
| 858 |
+
"cell_type": "markdown",
|
| 859 |
+
"source": [
|
| 860 |
+
"## Step 12: Push Final Model to Hub"
|
| 861 |
+
],
|
| 862 |
+
"metadata": {}
|
| 863 |
+
},
|
| 864 |
+
{
|
| 865 |
+
"cell_type": "code",
|
| 866 |
+
"source": [
|
| 867 |
+
"# Save model with proper label mapping\n",
|
| 868 |
+
"stage2_model.config.id2label = {str(i): name for i, name in enumerate(CUAD_LABELS)}\n",
|
| 869 |
+
"stage2_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}\n",
|
| 870 |
+
"\n",
|
| 871 |
+
"# Save locally\n",
|
| 872 |
+
"FINAL_DIR = \"./clauseguard-deberta-final\"\n",
|
| 873 |
+
"stage2_trainer.save_model(FINAL_DIR)\n",
|
| 874 |
+
"tokenizer.save_pretrained(FINAL_DIR)\n",
|
| 875 |
+
"\n",
|
| 876 |
+
"# Push to Hub\n",
|
| 877 |
+
"print(f\"\\n☁️ Pushing model to Hub: {HUB_MODEL_ID}\")\n",
|
| 878 |
+
"stage2_trainer.push_to_hub(\n",
|
| 879 |
+
" commit_message=(\n",
|
| 880 |
+
" f\"ClauseGuard v4: DeBERTa-v3-large 2-stage (LEDGAR→CUAD) with ASL\\n\"\n",
|
| 881 |
+
" f\"CUAD Test: micro-F1={test_results['eval_micro_f1']:.4f}, \"\n",
|
| 882 |
+
" f\"macro-F1={test_results['eval_macro_f1']:.4f}\"\n",
|
| 883 |
+
" )\n",
|
| 884 |
+
")\n",
|
| 885 |
+
"\n",
|
| 886 |
+
"print(f\"\\n✅ Model pushed to: https://huggingface.co/{HUB_MODEL_ID}\")"
|
| 887 |
+
],
|
| 888 |
+
"metadata": {},
|
| 889 |
+
"execution_count": null,
|
| 890 |
+
"outputs": []
|
| 891 |
+
},
|
| 892 |
+
{
|
| 893 |
+
"cell_type": "markdown",
|
| 894 |
+
"source": [
|
| 895 |
+
"## Step 13: Test the Model on Sample Clauses"
|
| 896 |
+
],
|
| 897 |
+
"metadata": {}
|
| 898 |
+
},
|
| 899 |
+
{
|
| 900 |
+
"cell_type": "code",
|
| 901 |
+
"source": [
|
| 902 |
+
"from transformers import pipeline as hf_pipeline\n",
|
| 903 |
+
"\n",
|
| 904 |
+
"# Load the trained model for inference\n",
|
| 905 |
+
"classifier = hf_pipeline(\n",
|
| 906 |
+
" \"text-classification\",\n",
|
| 907 |
+
" model=stage2_model,\n",
|
| 908 |
+
" tokenizer=tokenizer,\n",
|
| 909 |
+
" top_k=5, # return top 5 predictions\n",
|
| 910 |
+
" device=0 if torch.cuda.is_available() else -1,\n",
|
| 911 |
+
")\n",
|
| 912 |
+
"\n",
|
| 913 |
+
"test_clauses = [\n",
|
| 914 |
+
" # High-risk clauses\n",
|
| 915 |
+
" \"The Company may terminate this Agreement at any time, with or without cause, upon written notice to the other party.\",\n",
|
| 916 |
+
" \"In no event shall the Company be liable for any indirect, incidental, special, or consequential damages arising out of this Agreement.\",\n",
|
| 917 |
+
" \"All intellectual property developed during the term of this Agreement shall be owned exclusively by the Company.\",\n",
|
| 918 |
+
" \"This Agreement shall be governed by and construed in accordance with the laws of the State of Delaware.\",\n",
|
| 919 |
+
" \"Any disputes arising out of this Agreement shall be resolved through binding arbitration in New York.\",\n",
|
| 920 |
+
" \"The Employee agrees not to compete with the Company for a period of two (2) years following termination.\",\n",
|
| 921 |
+
" # Neutral clauses\n",
|
| 922 |
+
" \"This Agreement shall be effective as of January 1, 2024.\",\n",
|
| 923 |
+
" \"The initial term of this Agreement shall be three (3) years.\",\n",
|
| 924 |
+
" \"Either party may assign this Agreement with the prior written consent of the other party.\",\n",
|
| 925 |
+
"]\n",
|
| 926 |
+
"\n",
|
| 927 |
+
"print(\"🧪 Testing model on sample clauses:\\n\")\n",
|
| 928 |
+
"for clause in test_clauses:\n",
|
| 929 |
+
" results = classifier(clause, truncation=True, max_length=MAX_LENGTH)\n",
|
| 930 |
+
" top = results[0] if isinstance(results[0], dict) else results[0][0]\n",
|
| 931 |
+
" top3 = results[:3] if isinstance(results[0], dict) else results[0][:3]\n",
|
| 932 |
+
" \n",
|
| 933 |
+
" print(f\"📄 \\\"{clause[:90]}{'...' if len(clause) > 90 else ''}\\\"\")\n",
|
| 934 |
+
" for r in top3:\n",
|
| 935 |
+
" score = r['score']\n",
|
| 936 |
+
" bar = '█' * int(score * 20)\n",
|
| 937 |
+
" print(f\" → {r['label']:40s} {score:.4f} {bar}\")\n",
|
| 938 |
+
" print()"
|
| 939 |
+
],
|
| 940 |
+
"metadata": {},
|
| 941 |
+
"execution_count": null,
|
| 942 |
+
"outputs": []
|
| 943 |
+
},
|
| 944 |
+
{
|
| 945 |
+
"cell_type": "markdown",
|
| 946 |
+
"source": [
|
| 947 |
+
"## Step 14: Generate Updated app.py Integration Code\n",
|
| 948 |
+
"\n",
|
| 949 |
+
"Copy-paste this into your ClauseGuard Space's `app.py` to use the new model."
|
| 950 |
+
],
|
| 951 |
+
"metadata": {}
|
| 952 |
+
},
|
| 953 |
+
{
|
| 954 |
+
"cell_type": "code",
|
| 955 |
+
"source": [
|
| 956 |
+
"integration_code = f'''\n",
|
| 957 |
+
"# ═══════════════════════════════════════════════════════════════\n",
|
| 958 |
+
"# ClauseGuard v4 — Integration Code\n",
|
| 959 |
+
"# Replace the model loading section in app.py with this:\n",
|
| 960 |
+
"# ═══════════════════════════════════════════════════════════════\n",
|
| 961 |
+
"\n",
|
| 962 |
+
"# OLD (remove these):\n",
|
| 963 |
+
"# base = \"nlpaueb/legal-bert-base-uncased\"\n",
|
| 964 |
+
"# adapter = \"Mokshith31/legalbert-contract-clause-classification\"\n",
|
| 965 |
+
"# from peft import PeftModel\n",
|
| 966 |
+
"\n",
|
| 967 |
+
"# NEW:\n",
|
| 968 |
+
"CLAUSEGUARD_MODEL = \"{HUB_MODEL_ID}\"\n",
|
| 969 |
+
"\n",
|
| 970 |
+
"def _load_cuad_model():\n",
|
| 971 |
+
" global cuad_tokenizer, cuad_model, _model_status\n",
|
| 972 |
+
" if not _HAS_TORCH:\n",
|
| 973 |
+
" _model_status[\"cuad\"] = \"unavailable\"\n",
|
| 974 |
+
" return\n",
|
| 975 |
+
" try:\n",
|
| 976 |
+
" print(f\"[ClauseGuard] Loading classifier: {{CLAUSEGUARD_MODEL}}\")\n",
|
| 977 |
+
" cuad_tokenizer = AutoTokenizer.from_pretrained(CLAUSEGUARD_MODEL)\n",
|
| 978 |
+
" cuad_model = AutoModelForSequenceClassification.from_pretrained(CLAUSEGUARD_MODEL)\n",
|
| 979 |
+
" cuad_model.eval()\n",
|
| 980 |
+
" _model_status[\"cuad\"] = \"loaded\"\n",
|
| 981 |
+
" print(f\"[ClauseGuard] Model loaded: {{sum(p.numel() for p in cuad_model.parameters()):,}} params\")\n",
|
| 982 |
+
" except Exception as e:\n",
|
| 983 |
+
" print(f\"[ClauseGuard] Model load failed: {{e}}\")\n",
|
| 984 |
+
" _model_status[\"cuad\"] = f\"failed: {{e}}\"\n",
|
| 985 |
+
"\n",
|
| 986 |
+
"# In classify_cuad(), change max_length:\n",
|
| 987 |
+
"# max_length=256 → max_length=512\n",
|
| 988 |
+
"#\n",
|
| 989 |
+
"# Also: since the new model is single-label (softmax),\n",
|
| 990 |
+
"# change the prediction logic from sigmoid to:\n",
|
| 991 |
+
"#\n",
|
| 992 |
+
"# probs = torch.softmax(logits, dim=-1)[0] # instead of sigmoid\n",
|
| 993 |
+
"# top_indices = torch.argsort(probs, descending=True)[:5]\n",
|
| 994 |
+
"# for i in top_indices:\n",
|
| 995 |
+
"# if float(probs[i]) > 0.10: # confidence threshold\n",
|
| 996 |
+
"# label = CUAD_LABELS[i]\n",
|
| 997 |
+
"# ...\n",
|
| 998 |
+
"\n",
|
| 999 |
+
"# No more PEFT dependency needed!\n",
|
| 1000 |
+
"# No more ignore_mismatched_sizes!\n",
|
| 1001 |
+
"# Just load directly — the model already has the correct head.\n",
|
| 1002 |
+
"'''\n",
|
| 1003 |
+
"\n",
|
| 1004 |
+
"print(integration_code)"
|
| 1005 |
+
],
|
| 1006 |
+
"metadata": {},
|
| 1007 |
+
"execution_count": null,
|
| 1008 |
+
"outputs": []
|
| 1009 |
+
},
|
| 1010 |
+
{
|
| 1011 |
+
"cell_type": "markdown",
|
| 1012 |
+
"source": [
|
| 1013 |
+
"## Step 15: Comparison with Current Model\n",
|
| 1014 |
+
"\n",
|
| 1015 |
+
"| Metric | Current (Legal-BERT + LoRA) | New (DeBERTa-v3-large + ASL) |\n",
|
| 1016 |
+
"|--------|---------------------------|-----------------------------|\n",
|
| 1017 |
+
"| Base model | 110M params | 435M params |\n",
|
| 1018 |
+
"| Training | LoRA (frozen backbone) | Full fine-tune |\n",
|
| 1019 |
+
"| Pre-training | None | LEDGAR (60K, 100 classes) |\n",
|
| 1020 |
+
"| Max tokens | 256 | 512 |\n",
|
| 1021 |
+
"| Loss function | Cross-entropy | Asymmetric Loss |\n",
|
| 1022 |
+
"| Zero-F1 classes | 10 of 41 | TBD (should be much fewer) |\n",
|
| 1023 |
+
"| Macro-F1 | ~50% | Target: 78-87% |\n",
|
| 1024 |
+
"\n",
|
| 1025 |
+
"---\n",
|
| 1026 |
+
"\n",
|
| 1027 |
+
"## ✅ Done!\n",
|
| 1028 |
+
"\n",
|
| 1029 |
+
"Your trained model is at: **https://huggingface.co/gaurv007/clauseguard-deberta-v3-large**\n",
|
| 1030 |
+
"\n",
|
| 1031 |
+
"### Next Steps:\n",
|
| 1032 |
+
"1. Update ClauseGuard Space to use this model (see integration code above)\n",
|
| 1033 |
+
"2. Remove PEFT dependency from requirements.txt\n",
|
| 1034 |
+
"3. Consider training SetFit classifiers for any remaining zero-F1 classes\n",
|
| 1035 |
+
"4. Add OCR support (Feature #2)\n",
|
| 1036 |
+
"5. Add RAG chatbot (Feature #3)"
|
| 1037 |
+
],
|
| 1038 |
+
"metadata": {}
|
| 1039 |
+
}
|
| 1040 |
+
]
|
| 1041 |
+
}
|
ml/requirements.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
transformers=
|
| 2 |
datasets>=3.2.0
|
| 3 |
torch>=2.5.0
|
| 4 |
scikit-learn>=1.6.0
|
| 5 |
accelerate>=1.2.0
|
| 6 |
-
|
|
|
|
| 1 |
+
transformers>=5.6.0
|
| 2 |
datasets>=3.2.0
|
| 3 |
torch>=2.5.0
|
| 4 |
scikit-learn>=1.6.0
|
| 5 |
accelerate>=1.2.0
|
| 6 |
+
huggingface_hub>=0.27.0
|
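The Step 14 integration notes above switch the Space's prediction logic from sigmoid (multi-label) to softmax top-k with a 0.10 confidence threshold. A standalone sketch of that decision rule, using plain Python instead of torch tensors and an illustrative four-label subset of `CUAD_LABELS`:

```python
import math

# Illustrative subset of the 41 CUAD labels
LABELS = ["Governing Law", "Non-Compete", "Cap on Liability", "Other"]


def top_predictions(logits, labels=LABELS, k=5, threshold=0.10):
    """Softmax the logits, then return the top-k (label, prob) pairs
    whose probability clears the confidence threshold."""
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in ranked[:k] if probs[i] > threshold]


preds = top_predictions([3.0, 1.0, 0.5, -1.0])
# Only the two labels above the 0.10 threshold survive
```

The threshold and top-k values here are the ones stated in the integration comment; with a single-label softmax head, usually only one or two labels clear 0.10.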
ml/train_classifier_v4.py
ADDED
|
@@ -0,0 +1,434 @@
|
| 1 |
+
"""
|
| 2 |
+
ClauseGuard v4 — 2-Stage DeBERTa-v3-large Training Script
|
| 3 |
+
═══════════════════════════════════════════════════════════
|
| 4 |
+
|
| 5 |
+
Stage 1: Pre-fine-tune on LEDGAR (60K legal provisions, 100 classes)
|
| 6 |
+
Stage 2: Fine-tune on CUAD (41 classes) with Asymmetric Loss
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python train_classifier_v4.py # Full 2-stage pipeline
|
| 10 |
+
python train_classifier_v4.py --stage 1 # Stage 1 only
|
| 11 |
+
python train_classifier_v4.py --stage 2 --checkpoint ./stage1_ledgar_best # Stage 2 only
|
| 12 |
+
|
| 13 |
+
Requirements:
|
| 14 |
+
pip install transformers datasets scikit-learn accelerate torch
|
| 15 |
+
|
| 16 |
+
Hardware: A100 80GB recommended (~4-6 hours total)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import gc
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
from collections import Counter
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
from datasets import load_dataset, Dataset
|
| 31 |
+
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
|
| 32 |
+
from sklearn.model_selection import train_test_split
|
| 33 |
+
from transformers import (
|
| 34 |
+
AutoConfig,
|
| 35 |
+
AutoModelForSequenceClassification,
|
| 36 |
+
AutoTokenizer,
|
| 37 |
+
DataCollatorWithPadding,
|
| 38 |
+
Trainer,
|
| 39 |
+
TrainingArguments,
|
| 40 |
+
EarlyStoppingCallback,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ═══════════════════════════════════════════════════════════════
|
| 45 |
+
# CONFIGURATION
|
| 46 |
+
# ═══════════════════════════════════════════════════════════════
|
| 47 |
+
|
| 48 |
+
BASE_MODEL = os.environ.get("BASE_MODEL", "microsoft/deberta-v3-large")
|
| 49 |
+
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
|
| 50 |
+
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-deberta-v3-large")
|
| 51 |
+
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
|
| 52 |
+
SEED = 42
|
| 53 |
+
|
| 54 |
+
CUAD_LABELS = [
|
| 55 |
+
"Document Name", "Parties", "Agreement Date", "Effective Date",
|
| 56 |
+
"Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
|
| 57 |
+
"Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
|
| 58 |
+
"No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
|
| 59 |
+
"Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
|
| 60 |
+
"Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
|
| 61 |
+
"Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
|
| 62 |
+
"Joint IP Ownership", "License Grant", "Non-Transferable License",
|
| 63 |
+
"Affiliate License-Licensor", "Affiliate License-Licensee",
|
| 64 |
+
"Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
|
| 65 |
+
"Source Code Escrow", "Post-Termination Services", "Audit Rights",
|
| 66 |
+
"Uncapped Liability", "Cap on Liability", "Liquidated Damages",
|
| 67 |
+
"Warranty Duration", "Insurance", "Covenant Not to Sue",
|
| 68 |
+
"Third Party Beneficiary", "Other",
|
| 69 |
+
]
|
| 70 |
+
NUM_CUAD_LABELS = len(CUAD_LABELS)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ═══════════════════════════════════════════════════════════════
|
| 74 |
+
# ASYMMETRIC LOSS (arxiv:2009.14119)
|
| 75 |
+
# ═══════════════════════════════════════════════════════════════
|
| 76 |
+
|
| 77 |
+
class AsymmetricLoss(nn.Module):
|
| 78 |
+
    """Focal-style cross-entropy for class imbalance (ASL, arXiv:2009.14119).

    Single-label variant: gamma_neg is the focusing exponent applied to
    (1 - p_t); gamma_pos, clip, and eps are accepted for API parity with
    the multi-label formulation but are unused in forward().
    """
|
| 79 |
+
|
| 80 |
+
def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,
|
| 81 |
+
class_weights=None):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.gamma_pos = gamma_pos
|
| 84 |
+
self.gamma_neg = gamma_neg
|
| 85 |
+
self.clip = clip
|
| 86 |
+
self.eps = eps
|
| 87 |
+
if class_weights is not None:
|
| 88 |
+
self.register_buffer('class_weights',
|
| 89 |
+
torch.tensor(class_weights, dtype=torch.float32))
|
| 90 |
+
else:
|
| 91 |
+
self.class_weights = None
|
| 92 |
+
|
| 93 |
+
def forward(self, logits, targets):
|
| 94 |
+
"""Multi-class focal cross-entropy with class weights."""
|
| 95 |
+
if self.class_weights is not None:
|
| 96 |
+
ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights,
|
| 97 |
+
reduction='none')
|
| 98 |
+
else:
|
| 99 |
+
ce_loss = F.cross_entropy(logits, targets, reduction='none')
|
| 100 |
+
|
| 101 |
+
probs = F.softmax(logits, dim=-1)
|
| 102 |
+
p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
|
| 103 |
+
focal_weight = (1 - p_t) ** self.gamma_neg
|
| 104 |
+
loss = focal_weight * ce_loss
|
| 105 |
+
return loss.mean()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ═══════════════════════════════════════════════════════════════
|
| 109 |
+
# CUSTOM TRAINER
|
| 110 |
+
# ═══════════════════════════════════════════════════════════════
|
| 111 |
+
|
| 112 |
+
class ASLTrainer(Trainer):
|
| 113 |
+
def __init__(self, *args, asl_loss_fn=None, **kwargs):
|
| 114 |
+
super().__init__(*args, **kwargs)
|
| 115 |
+
self.asl = asl_loss_fn
|
| 116 |
+
|
| 117 |
+
def compute_loss(self, model, inputs, return_outputs=False,
|
| 118 |
+
num_items_in_batch=None):
|
| 119 |
+
labels = inputs.pop("labels")
|
| 120 |
+
outputs = model(**inputs)
|
| 121 |
+
logits = outputs.logits
|
| 122 |
+
if self.asl is not None:
|
| 123 |
+
loss = self.asl(logits, labels)
|
| 124 |
+
else:
|
| 125 |
+
loss = F.cross_entropy(logits, labels)
|
| 126 |
+
return (loss, outputs) if return_outputs else loss
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ═══════════════════════════════════════════════════════════════
|
| 130 |
+
# METRICS
|
| 131 |
+
# ═══════════════════════════════════════════════════════════════
|
| 132 |
+
|
| 133 |
+
def compute_metrics(eval_pred):
|
| 134 |
+
logits, labels = eval_pred.predictions, eval_pred.label_ids
|
| 135 |
+
preds = np.argmax(logits, axis=-1)
|
| 136 |
+
return {
|
| 137 |
+
"accuracy": (preds == labels).mean(),
|
| 138 |
+
"micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
|
| 139 |
+
"macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
|
| 140 |
+
"weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ═══════════════════════════════════════════════════════════════
|
| 145 |
+
# STAGE 1: LEDGAR
|
| 146 |
+
# ═══════════════════════════════════════════════════════════════
|
| 147 |
+
|
| 148 |
+
def run_stage1(tokenizer, output_dir="./stage1_ledgar_best"):
|
| 149 |
+
print("\n" + "=" * 60)
|
| 150 |
+
print(" STAGE 1: Pre-fine-tune on LEDGAR (100 classes)")
|
| 151 |
+
print("=" * 60)
|
| 152 |
+
|
| 153 |
+
ledgar = load_dataset("coastalcph/lex_glue", "ledgar")
|
| 154 |
+
num_labels = ledgar['train'].features['label'].num_classes
|
| 155 |
+
print(f" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,}")
|
| 156 |
+
print(f" Classes: {num_labels}")
|
| 157 |
+
|
| 158 |
+
def preprocess(examples):
|
| 159 |
+
tok = tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH,
|
| 160 |
+
padding=False)
|
| 161 |
+
tok["labels"] = examples["label"]
|
| 162 |
+
return tok
|
| 163 |
+
|
| 164 |
+
tokenized = ledgar.map(preprocess, batched=True,
|
| 165 |
+
remove_columns=ledgar["train"].column_names)
|
| 166 |
+
|
| 167 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 168 |
+
BASE_MODEL, num_labels=num_labels,
|
| 169 |
+
problem_type="single_label_classification",
|
| 170 |
+
ignore_mismatched_sizes=True,
|
| 171 |
+
)
|
| 172 |
+
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 173 |
+
|
| 174 |
+
args = TrainingArguments(
|
| 175 |
+
output_dir="./stage1_ledgar",
|
| 176 |
+
num_train_epochs=5,
|
| 177 |
+
per_device_train_batch_size=8,
|
| 178 |
+
per_device_eval_batch_size=16,
|
| 179 |
+
gradient_accumulation_steps=4,
|
| 180 |
+
learning_rate=2e-5,
|
| 181 |
+
weight_decay=0.06,
|
| 182 |
+
warmup_ratio=0.1,
|
| 183 |
+
lr_scheduler_type="cosine",
|
| 184 |
+
eval_strategy="epoch",
|
| 185 |
+
save_strategy="epoch",
|
| 186 |
+
save_total_limit=2,
|
| 187 |
+
load_best_model_at_end=True,
|
| 188 |
+
metric_for_best_model="macro_f1",
|
| 189 |
+
greater_is_better=True,
|
| 190 |
+
bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
|
| 191 |
+
fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
|
| 192 |
+
logging_strategy="steps",
|
| 193 |
+
logging_steps=50,
|
| 194 |
+
logging_first_step=True,
|
| 195 |
+
disable_tqdm=True,
|
| 196 |
+
report_to="none",
|
| 197 |
+
dataloader_num_workers=2,
|
| 198 |
+
seed=SEED,
|
| 199 |
+
gradient_checkpointing=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
trainer = Trainer(
|
| 203 |
+
model=model, args=args,
|
| 204 |
+
train_dataset=tokenized["train"],
|
| 205 |
+
eval_dataset=tokenized["validation"],
|
| 206 |
+
processing_class=tokenizer,
|
| 207 |
+
data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
|
| 208 |
+
compute_metrics=compute_metrics,
|
| 209 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
result = trainer.train()
|
| 213 |
+
print(f"\n Stage 1 training loss: {result.training_loss:.4f}")
|
| 214 |
+
|
| 215 |
+
test_metrics = trainer.evaluate(tokenized["test"])
|
| 216 |
+
print(f" Stage 1 test micro-F1: {test_metrics['eval_micro_f1']:.4f}")
|
| 217 |
+
print(f" Stage 1 test macro-F1: {test_metrics['eval_macro_f1']:.4f}")
|
| 218 |
+
|
| 219 |
+
trainer.save_model(output_dir)
|
| 220 |
+
tokenizer.save_pretrained(output_dir)
|
| 221 |
+
print(f" Saved to {output_dir}")
|
| 222 |
+
|
| 223 |
+
del model, trainer
|
| 224 |
+
torch.cuda.empty_cache()
|
| 225 |
+
gc.collect()
|
| 226 |
+
|
| 227 |
+
return output_dir
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ═══════════════════════════════════════════════════════════════
|
| 231 |
+
# STAGE 2: CUAD
|
| 232 |
+
# ═══════════════════════════════════════════════════════════════
|
| 233 |
+
|
| 234 |
+
def run_stage2(tokenizer, checkpoint_path, output_dir="./clauseguard-deberta-final"):
|
| 235 |
+
print("\n" + "=" * 60)
|
| 236 |
+
print(f" STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL")
|
| 237 |
+
print("=" * 60)
|
| 238 |
+
|
| 239 |
+
# Load and split CUAD
|
| 240 |
+
cuad_raw = load_dataset(
|
| 241 |
+
"dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
|
| 242 |
+
split="train"
|
| 243 |
+
)
|
| 244 |
+
cuad_df = cuad_raw.to_pandas()
|
| 245 |
+
|
| 246 |
+
unique_files = cuad_df['file_name'].unique()
|
| 247 |
+
train_files, test_files = train_test_split(unique_files, test_size=0.2,
|
| 248 |
+
random_state=SEED)
|
| 249 |
+
    val_files, test_files = train_test_split(test_files, test_size=0.5,
                                             random_state=SEED)

    splits = {
        "train": Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(train_files)].reset_index(drop=True)
        ),
        "val": Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(val_files)].reset_index(drop=True)
        ),
        "test": Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(test_files)].reset_index(drop=True)
        ),
    }

    for name, ds in splits.items():
        print(f" {name}: {len(ds)} rows")

    def preprocess_cuad(examples):
        tok = tokenizer(examples["clause"], truncation=True, max_length=MAX_LENGTH,
                        padding=False)
        tok["labels"] = examples["class_id"]
        return tok

    tok_splits = {}
    for name, ds in splits.items():
        tok_splits[name] = ds.map(preprocess_cuad, batched=True,
                                  remove_columns=ds.column_names)

    # Load model from Stage 1 checkpoint
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_path,
        num_labels=NUM_CUAD_LABELS,
        ignore_mismatched_sizes=True,
        problem_type="single_label_classification",
    )

    # Update label mapping
    model.config.id2label = {str(i): name for i, name in enumerate(CUAD_LABELS)}
    model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Compute class weights
    train_counts = Counter(tok_splits["train"]["labels"])
    total = sum(train_counts.values())
    class_weights = []
    for i in range(NUM_CUAD_LABELS):
        count = train_counts.get(i, 1)
        weight = min(10.0, total / (NUM_CUAD_LABELS * count))
        class_weights.append(weight)

    asl = AsymmetricLoss(gamma_pos=0, gamma_neg=4, clip=0.05,
                         class_weights=class_weights)
    if torch.cuda.is_available():
        asl = asl.cuda()

    args = TrainingArguments(
        output_dir="./stage2_cuad",
        num_train_epochs=20,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
        fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        push_to_hub=PUSH_TO_HUB,
        hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
    )

    trainer = ASLTrainer(
        model=model, args=args,
        asl_loss_fn=asl,
        train_dataset=tok_splits["train"],
        eval_dataset=tok_splits["val"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    result = trainer.train()
    print(f"\n Stage 2 training loss: {result.training_loss:.4f}")

    # Evaluate
    test_metrics = trainer.evaluate(tok_splits["test"])
    print(f"\n{'='*60}")
    print(f" CUAD TEST RESULTS")
    print(f"{'='*60}")
    print(f" Accuracy:    {test_metrics['eval_accuracy']:.4f}")
    print(f" Micro-F1:    {test_metrics['eval_micro_f1']:.4f}")
    print(f" Macro-F1:    {test_metrics['eval_macro_f1']:.4f}")
    print(f" Weighted-F1: {test_metrics['eval_weighted_f1']:.4f}")

    # Full report
    preds_out = trainer.predict(tok_splits["test"])
    preds = np.argmax(preds_out.predictions, axis=-1)
    labels = preds_out.label_ids
    present = sorted(set(labels) | set(preds))
    names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f"Class-{i}" for i in present]
    print("\n" + classification_report(labels, preds, labels=present,
                                       target_names=names, zero_division=0, digits=4))

    # Save
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    if PUSH_TO_HUB:
        trainer.push_to_hub(
            commit_message=(
                f"ClauseGuard v4: DeBERTa-v3-large LEDGAR→CUAD + ASL | "
                f"micro-F1={test_metrics['eval_micro_f1']:.4f} "
                f"macro-F1={test_metrics['eval_macro_f1']:.4f}"
            )
        )
        print(f"\n Pushed to https://huggingface.co/{HUB_MODEL_ID}")

    # Save test results
    results_path = os.path.join(output_dir, "test_results.json")
    with open(results_path, "w") as f:
        json.dump({
            "model": HUB_MODEL_ID,
            "base_model": BASE_MODEL,
            "max_length": MAX_LENGTH,
            "stage1_dataset": "coastalcph/lex_glue (ledgar)",
            "stage2_dataset": "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
            "test_results": {k: float(v) for k, v in test_metrics.items()
                             if isinstance(v, (int, float))},
            "timestamp": datetime.now().isoformat(),
        }, f, indent=2)

    return output_dir


# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════

def main():
    parser = argparse.ArgumentParser(description="ClauseGuard v4 Training")
    parser.add_argument("--stage", type=int, default=0,
                        help="Run specific stage (1 or 2). Default: both")
    parser.add_argument("--checkpoint", type=str, default="./stage1_ledgar_best",
                        help="Stage 1 checkpoint path for Stage 2")
    args = parser.parse_args()

    print(f"🛡️ ClauseGuard v4 Training")
    print(f" Model: {BASE_MODEL}")
    print(f" Max length: {MAX_LENGTH}")
    print(f" Hub: {HUB_MODEL_ID}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    if args.stage in (0, 1):
        checkpoint = run_stage1(tokenizer)
    else:
        checkpoint = args.checkpoint

    if args.stage in (0, 2):
        run_stage2(tokenizer, checkpoint)

    print("\n✅ Training complete!")


if __name__ == "__main__":
    main()
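The inverse-frequency class weighting used above (capped at 10.0 so ultra-rare classes don't dominate the loss) is easy to sanity-check in isolation. A minimal sketch mirroring the script's `min(10.0, total / (num_labels * count))` formula; the helper name and toy label list are invented for illustration:

```python
from collections import Counter

def compute_class_weights(labels, num_labels, cap=10.0):
    """Inverse-frequency weights, capped so very rare classes don't dominate."""
    counts = Counter(labels)
    total = sum(counts.values())
    # Same formula as the training script: min(cap, total / (num_labels * count))
    return [min(cap, total / (num_labels * counts.get(i, 1)))
            for i in range(num_labels)]

# Toy example: class 0 is common, class 2 is rare → rare class gets a larger weight
weights = compute_class_weights([0, 0, 0, 0, 1, 1, 2], num_labels=3)
```

Rare classes receive proportionally larger weights, which the asymmetric loss then applies per class during Stage 2 fine-tuning.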
obligations.py
CHANGED
@@ -120,18 +120,22 @@ def extract_obligations(text):
         if not found_types:
             continue
 
-        # Extract party
+        # Extract party (Fix 8: scope to sentence only, reject >40 char strings)
         party = "Unknown"
-
-        m = re.search(pp, sentence)
-        if m:
-            party = m.group(0).strip()
-            break
-
-        # Try to determine which party has the obligation based on sentence structure
+        # First try structured direction detection
         obligation_direction = _detect_obligation_direction(sentence)
         if obligation_direction:
             party = obligation_direction
+        else:
+            # Fallback to pattern matching within the sentence
+            for pp in PARTY_PATTERNS:
+                m = re.search(pp, sentence)
+                if m:
+                    candidate = m.group(0).strip()
+                    # Fix 8: Reject party strings >40 chars (header bleed-through)
+                    if len(candidate) <= 40:
+                        party = candidate
+                        break
 
         # Extract timeframe
         deadline = "Not specified"
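The fixed fallback in the hunk above is small enough to illustrate standalone. A minimal sketch; the pattern list and sentences are invented stand-ins for the module's PARTY_PATTERNS, while the ≤40-character guard mirrors Fix 8:

```python
import re

# Hypothetical subset of the module's PARTY_PATTERNS
PARTY_PATTERNS = [r"\bthe (?:Company|Provider|Customer|Contractor)\b"]

def extract_party(sentence, max_len=40):
    """Return the first matching party phrase from the sentence, rejecting
    over-long matches (header bleed-through); otherwise 'Unknown'."""
    for pp in PARTY_PATTERNS:
        m = re.search(pp, sentence)
        if m:
            candidate = m.group(0).strip()
            if len(candidate) <= max_len:  # Fix 8: reject header bleed-through
                return candidate
    return "Unknown"
```

Scoping the search to the single sentence (rather than the whole document) plus the length cap is what stops contract headers from being reported as obligated parties.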
ocr_engine.py
ADDED
@@ -0,0 +1,218 @@
"""
ClauseGuard — OCR Engine v1.0
═════════════════════════════
Smart PDF Router: detects native vs scanned PDFs.
  • Native PDF  → pdfplumber (fast, existing)
  • Scanned PDF → docTR OCR (CPU-friendly, ~150MB models)

Architecture:
  PDF uploaded
      ↓
  [detect_if_scanned] — pdfplumber gets <50 chars/page?
      ↓                     ↓
  Native PDF            Scanned PDF
      ↓                     ↓
  pdfplumber            docTR OCR (CPU)
      ↓                     ↓
  Contract text → existing analysis pipeline
"""

import os
import re

# ── docTR (soft-fail) ───────────────────────────────────────────────
_HAS_DOCTR = False
_ocr_predictor = None

try:
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor as _make_predictor
    _HAS_DOCTR = True
except ImportError:
    pass

# ── pdfplumber (soft-fail) ──────────────────────────────────────────
try:
    import pdfplumber
    _HAS_PDF = True
except ImportError:
    _HAS_PDF = False

# ═══════════════════════════════════════════════════════════════════════
# OCR MODEL LOADING
# ═══════════════════════════════════════════════════════════════════════

_ocr_status = "not_loaded"

def _load_ocr_model():
    """Load docTR OCR predictor (lazy, on first use)."""
    global _ocr_predictor, _ocr_status
    if _ocr_predictor is not None:
        return _ocr_predictor
    if not _HAS_DOCTR:
        _ocr_status = "unavailable (python-doctr not installed)"
        return None
    try:
        print("[ClauseGuard OCR] Loading docTR models (fast_base + crnn_vgg16_bn)...")
        _ocr_predictor = _make_predictor(
            det_arch="fast_base",
            reco_arch="crnn_vgg16_bn",
            pretrained=True,
            assume_straight_pages=True,
        )
        _ocr_status = "loaded"
        print("[ClauseGuard OCR] docTR models loaded successfully")
        return _ocr_predictor
    except Exception as e:
        _ocr_status = f"failed: {e}"
        print(f"[ClauseGuard OCR] docTR load failed: {e}")
        return None


def get_ocr_status():
    """Return human-readable OCR engine status."""
    if _ocr_predictor is not None:
        return "✅ OCR: docTR loaded"
    elif _HAS_DOCTR:
        return "⏳ OCR: docTR available (not yet loaded)"
    else:
        return "❌ OCR: unavailable (python-doctr not installed)"


# ═══════════════════════════════════════════════════════════════════════
# SMART PDF ROUTER
# ═══════════════════════════════════════════════════════════════════════

def _is_scanned_pdf(file_path, min_chars_per_page=50):
    """
    Detect if a PDF is scanned (image-based) by checking if pdfplumber
    extracts fewer than `min_chars_per_page` characters on average.
    """
    if not _HAS_PDF:
        return True  # Can't check with pdfplumber, assume scanned
    try:
        with pdfplumber.open(file_path) as pdf:
            if len(pdf.pages) == 0:
                return True
            total_chars = 0
            pages_checked = min(len(pdf.pages), 5)  # Check first 5 pages
            for i in range(pages_checked):
                page_text = pdf.pages[i].extract_text() or ""
                total_chars += len(page_text.strip())
            avg_chars = total_chars / pages_checked
            return avg_chars < min_chars_per_page
    except Exception:
        return True  # If pdfplumber fails, try OCR


def _extract_native_pdf(file_path):
    """Extract text from a native (digital) PDF using pdfplumber."""
    if not _HAS_PDF:
        return None, "pdfplumber not installed"
    try:
        text = ""
        with pdfplumber.open(file_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n\n"
        if not text.strip():
            return None, "No text extracted from PDF"
        return text.strip(), None
    except Exception as e:
        return None, f"PDF parse error: {e}"


def _extract_scanned_pdf(file_path):
    """Extract text from a scanned PDF using docTR OCR."""
    predictor = _load_ocr_model()
    if predictor is None:
        return None, (
            "OCR is not available. Install python-doctr: "
            "`pip install python-doctr[torch]`"
        )
    try:
        doc = DocumentFile.from_pdf(file_path)
        result = predictor(doc)

        # Extract text page by page
        full_text = ""
        for page_idx, page in enumerate(result.pages):
            page_text = ""
            for block in page.blocks:
                for line in block.lines:
                    line_text = " ".join(word.value for word in line.words)
                    page_text += line_text + "\n"
                page_text += "\n"
            full_text += page_text + "\n\n"

        if not full_text.strip():
            return None, "OCR could not extract text from scanned PDF"

        # Clean up OCR artifacts
        full_text = _clean_ocr_text(full_text)
        return full_text.strip(), None
    except Exception as e:
        return None, f"OCR error: {e}"


def _clean_ocr_text(text):
    """Clean common OCR artifacts."""
    # Remove excessive whitespace
    text = re.sub(r'[ \t]{3,}', ' ', text)
    # Fix common OCR substitutions
    text = re.sub(r'\bl\b(?=[A-Z])', 'I', text)  # l before capital → I
    # Normalize line breaks
    text = re.sub(r'\n{4,}', '\n\n\n', text)
    # Remove single-char lines (OCR noise)
    lines = text.split('\n')
    cleaned_lines = []
    for line in lines:
        stripped = line.strip()
        if len(stripped) <= 1 and stripped not in ('', '.', ',', ';'):
            continue
        cleaned_lines.append(line)
    return '\n'.join(cleaned_lines)


# ═══════════════════════════════════════════════════════════════════════
# PUBLIC API
# ═══════════════════════════════════════════════════════════════════════

def parse_pdf_smart(file_path):
    """
    Smart PDF parser with OCR fallback.

    Returns: (text, error, method)
        text:   extracted text (or None)
        error:  error message (or None)
        method: "native" | "ocr" | None
    """
    if not os.path.exists(file_path):
        return None, "File not found", None

    # Step 1: Check if PDF is scanned
    is_scanned = _is_scanned_pdf(file_path)

    if not is_scanned:
        # Step 2a: Native PDF — use pdfplumber
        text, error = _extract_native_pdf(file_path)
        if text:
            return text, None, "native"
        # If pdfplumber returns empty, fall through to OCR
        print("[ClauseGuard OCR] pdfplumber returned empty — falling back to OCR")

    # Step 2b: Scanned PDF or pdfplumber failed — use OCR
    print(f"[ClauseGuard OCR] {'Scanned' if is_scanned else 'Empty native'} PDF detected — running docTR OCR...")
    text, error = _extract_scanned_pdf(file_path)
    if text:
        return text, None, "ocr"
    return None, error, None


def ocr_extract(file_path):
    """
    Force OCR extraction on a PDF (bypass native text check).
    Useful when user explicitly wants OCR.
    """
    return _extract_scanned_pdf(file_path)
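The router's scanned-PDF heuristic (average extracted characters over the first five pages, threshold 50) can be exercised without a real PDF. A minimal sketch of just the decision rule; the function name and sample page texts are invented for illustration:

```python
def is_scanned(page_texts, min_chars_per_page=50, sample_pages=5):
    """Mirror of _is_scanned_pdf's decision rule: average the stripped
    character counts of the first few pages; below threshold → scanned."""
    if not page_texts:
        return True  # empty PDF → assume scanned, try OCR
    sample = page_texts[:sample_pages]
    avg_chars = sum(len(t.strip()) for t in sample) / len(sample)
    return avg_chars < min_chars_per_page

# Image-only pages yield almost no extractable text → routed to OCR
scanned = is_scanned(["", " ", ""])
# A page of real contract text clears the threshold → routed to pdfplumber
native = is_scanned(["WHEREAS, the parties agree as follows... " * 20])
```

Averaging over a few pages (rather than checking page one alone) keeps a scanned cover page or blank title page from misrouting an otherwise native PDF.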
redlining.py
ADDED
@@ -0,0 +1,591 @@
"""
ClauseGuard — Clause Redlining Engine v1.0
═══════════════════════════════════════════
3-Tier Hybrid Architecture:
  Tier 1 — Template lookup (instant, zero hallucination risk)
  Tier 2 — RAG retrieval from clause corpus (find fairer precedents)
  Tier 3 — LLM refinement (adapt template using retrieved precedents)

Anti-hallucination guardrails:
  • Template anchor: LLM can only refine, not generate from scratch
  • RAG grounding: Retrieved precedents constrain the output space
  • Disclaimer: "Not legal advice. Consult an attorney before executing."
  • Legal citation: Prompt requires LLM to cite the consumer protection standard applied
"""

import os
import re
from collections import defaultdict

# ── HF Inference Client (soft-fail) ─────────────────────────────────
_HAS_INFERENCE = False
try:
    from huggingface_hub import InferenceClient
    _HAS_INFERENCE = True
except ImportError:
    pass

# ═══════════════════════════════════════════════════════════════════════
# TIER 1: TEMPLATE LIBRARY (18+ clause types)
# ═══════════════════════════════════════════════════════════════════════
# Based on FTC guidelines, EU Directive 93/13, and CFPB guidance.

SAFE_ALTERNATIVES = {
    # ── CRITICAL Risk Clauses ──────────────────────────────────────
    "Uncapped Liability": {
        "risky_pattern": "Total liability shall not exceed $1 / unlimited liability exposure",
        "safe_alternative": (
            "Provider's aggregate liability under this Agreement shall not exceed the total "
            "fees paid by the Customer in the twelve (12) months preceding the claim. "
            "This limitation shall not apply to: (a) gross negligence or willful misconduct, "
            "(b) breach of confidentiality obligations, (c) intellectual property indemnification "
            "obligations, or (d) violations of applicable law."
        ),
        "legal_basis": "UCC § 2-719; Restatement (Second) of Contracts § 356",
        "consumer_standard": "FTC guidelines on unconscionable contract terms",
        "risk_level": "CRITICAL",
    },
    "Arbitration": {
        "risky_pattern": "All disputes via binding arbitration / class action waiver",
        "safe_alternative": (
            "Disputes involving claims under [Dollar Amount] shall be resolved in small claims "
            "court in the consumer's jurisdiction of residence. For other disputes, either party "
            "may elect binding arbitration under [AAA/JAMS] rules. The consumer may opt out of "
            "arbitration by providing written notice within thirty (30) days of accepting these "
            "terms. Each party bears its own arbitration costs; the prevailing party may recover "
            "reasonable attorney's fees."
        ),
        "legal_basis": "Federal Arbitration Act § 2; AT&T Mobility v. Concepcion, 563 U.S. 333 (2011)",
        "consumer_standard": "CFPB Arbitration Rule guidance; EU Directive 93/13/EEC Art. 3",
        "risk_level": "CRITICAL",
    },
    "IP Ownership Assignment": {
        "risky_pattern": "All IP rights assigned to company / work-for-hire everything",
        "safe_alternative": (
            "Intellectual property created by the Receiving Party specifically in performance of "
            "this Agreement ('Work Product IP') shall be assigned to the Disclosing Party. "
            "Pre-existing IP and general knowledge, skills, and experience of the Receiving Party "
            "remain the Receiving Party's property. The Disclosing Party grants the Receiving Party "
            "a non-exclusive, perpetual license to use Work Product IP for internal portfolio and "
            "reference purposes."
        ),
        "legal_basis": "17 U.S.C. § 101 (work for hire); Copyright Act § 201(b)",
        "consumer_standard": "Standard IP assignment with carve-outs for pre-existing IP",
        "risk_level": "CRITICAL",
    },
    "Termination for Convenience": {
        "risky_pattern": "Terminate at any time without notice",
        "safe_alternative": (
            "Either party may terminate this Agreement for convenience upon thirty (30) days' "
            "prior written notice. Immediate termination is permitted only for material breach "
            "that remains uncured after a ten (10) day cure period following written notice "
            "specifying the breach. Upon termination: (a) all outstanding fees become due, "
            "(b) each party shall return or destroy confidential information within fifteen (15) "
            "business days, and (c) licenses granted hereunder shall terminate except as "
            "expressly stated to survive."
        ),
        "legal_basis": "Restatement (Second) of Contracts § 237; UCC § 2-309",
        "consumer_standard": "FTC: adequate notice period required for service termination",
        "risk_level": "CRITICAL",
    },
    "Limitation of liability": {
        "risky_pattern": "Company not liable for any damages / complete disclaimer",
        "safe_alternative": (
            "Neither party shall be liable for indirect, incidental, special, or consequential "
            "damages, EXCEPT in cases of: (a) gross negligence or willful misconduct, "
            "(b) breach of confidentiality, (c) data breach involving personal information, or "
            "(d) intellectual property infringement. Direct damages are limited to fees paid "
            "in the prior twelve (12) months. Nothing in this Agreement limits liability for "
            "death or personal injury caused by negligence."
        ),
        "legal_basis": "UCC § 2-719(3); EU Directive 93/13/EEC Annex (a)",
        "consumer_standard": "Cannot exclude liability for death/personal injury (EU/UK law)",
        "risk_level": "CRITICAL",
    },
    "Unilateral termination": {
        "risky_pattern": "Company can terminate account at any time without reason",
        "safe_alternative": (
            "The Provider may suspend or terminate the User's account for: (a) material breach "
            "of these Terms, (b) non-payment after ten (10) days' notice, (c) illegal activity, "
            "or (d) extended inactivity exceeding twelve (12) months. The Provider shall provide "
            "at least thirty (30) days' written notice before termination, except in cases of "
            "illegal activity. Upon termination, the User shall have thirty (30) days to export "
            "their data."
        ),
        "legal_basis": "EU Directive 2019/770 (Digital Content); CFPB guidance",
        "consumer_standard": "Right to export data upon termination; adequate notice period",
        "risk_level": "CRITICAL",
    },
    "Liquidated Damages": {
        "risky_pattern": "Pre-determined damages far exceeding actual harm",
        "safe_alternative": (
            "In the event of breach, the non-breaching party shall be entitled to liquidated "
            "damages in the amount of [specific reasonable amount], which the parties agree "
            "represents a reasonable estimate of anticipated harm. This liquidated damages "
            "provision shall not apply if actual damages are readily ascertainable, in which "
            "case the non-breaching party may recover actual damages proven."
        ),
        "legal_basis": "Restatement (Second) of Contracts § 356; UCC § 2-718",
        "consumer_standard": "Liquidated damages must be reasonable estimate, not penalty",
        "risk_level": "CRITICAL",
    },

    # ── HIGH Risk Clauses ──────────────────────────────────────────
    "Unilateral change": {
        "risky_pattern": "We may modify terms at any time without notice",
        "safe_alternative": (
            "Material changes to these Terms require thirty (30) days' advance written notice "
            "to the User via email and in-app notification. The User has the right to terminate "
            "without penalty within the notice period if they do not accept the changes. "
            "Non-material changes (e.g., formatting, clarifications) may be made without notice."
        ),
        "legal_basis": "EU Directive 93/13/EEC Art. 3; Restatement (Second) § 89",
        "consumer_standard": "FTC: material changes require notice and right to reject",
        "risk_level": "HIGH",
    },
    "Content removal": {
        "risky_pattern": "Company can delete content at sole discretion without notice",
        "safe_alternative": (
            "Content may be removed only for violation of these Terms of Service, applicable law, "
            "or valid legal process. The Provider shall provide prior notice specifying the reason "
            "for removal (except where legally prohibited). The User has the right to appeal "
            "within fourteen (14) days. Removed content shall be preserved for thirty (30) days "
            "to allow for appeal resolution."
        ),
        "legal_basis": "EU Digital Services Act Art. 17; First Amendment considerations",
        "consumer_standard": "Due process: notice, reason, and right to appeal",
        "risk_level": "HIGH",
    },
    "Non-Compete": {
        "risky_pattern": "Broad non-compete with no time/geography limits",
        "safe_alternative": (
            "During the term of this Agreement and for a period of [6-12] months thereafter, "
            "the Receiving Party shall not directly compete with the Disclosing Party in "
            "[specific market/geography]. This restriction applies only to [specific business "
            "activities] and does not prevent general employment in the industry. The Disclosing "
            "Party shall provide [garden leave pay / consideration] during the restricted period."
        ),
        "legal_basis": "Restatement (Second) of Contracts § 188; FTC Non-Compete Rule (2024)",
        "consumer_standard": "Reasonable scope, duration, geography; adequate consideration",
        "risk_level": "HIGH",
    },
    "Exclusivity": {
        "risky_pattern": "Exclusive dealing with no time limit or exit clause",
        "safe_alternative": (
            "The exclusivity arrangement shall apply for an initial term of [12-24] months, "
            "after which either party may convert to non-exclusive upon sixty (60) days' notice. "
            "Exclusivity is limited to [specific product/service category] and [specific "
            "geographic area]. Performance benchmarks shall be reviewed quarterly; failure to "
            "meet agreed minimums allows termination of exclusivity."
        ),
        "legal_basis": "Sherman Act § 1; EU Competition Law Art. 101 TFEU",
        "consumer_standard": "Time-limited, scope-limited, with performance exit clause",
        "risk_level": "HIGH",
    },
    "Anti-Assignment": {
        "risky_pattern": "Complete prohibition on assignment without consent",
        "safe_alternative": (
            "Neither party may assign this Agreement without the prior written consent of the "
            "other party, which shall not be unreasonably withheld, conditioned, or delayed. "
            "Notwithstanding the foregoing, either party may assign this Agreement without "
            "consent in connection with a merger, acquisition, or sale of substantially all "
            "of its assets, provided the assignee assumes all obligations hereunder."
        ),
        "legal_basis": "UCC § 2-210; Restatement (Second) of Contracts § 317",
        "consumer_standard": "Consent not to be unreasonably withheld; M&A carve-out",
        "risk_level": "HIGH",
    },

    # ── MEDIUM Risk Clauses ────────────────────────────────────────
    "Jurisdiction": {
        "risky_pattern": "Exclusive jurisdiction in distant/foreign state",
        "safe_alternative": (
            "The Consumer may bring claims in their jurisdiction of residence or the Provider's "
            "principal place of business. Small claims actions may be brought in any court of "
            "competent jurisdiction. For commercial contracts: disputes shall be resolved in "
            "[mutually agreed location] or the defendant's principal place of business."
        ),
        "legal_basis": "EU Regulation 1215/2012 (Brussels I); CJEU C-585/08",
        "consumer_standard": "Consumer may sue in home jurisdiction (EU Directive 93/13)",
        "risk_level": "MEDIUM",
    },
    "Choice of law": {
        "risky_pattern": "Governed by laws of a jurisdiction that disadvantages consumer",
        "safe_alternative": (
            "This Agreement shall be governed by the laws of [State/Country]. Notwithstanding "
|
| 216 |
+
"the foregoing, nothing in this choice of law provision shall deprive the Consumer "
|
| 217 |
+
"of the protection afforded by mandatory provisions of the law of the Consumer's "
|
| 218 |
+
"habitual residence."
|
| 219 |
+
),
|
| 220 |
+
"legal_basis": "EU Regulation 593/2008 (Rome I) Art. 6; UCC § 1-301",
|
| 221 |
+
"consumer_standard": "Cannot override mandatory consumer protection of home jurisdiction",
|
| 222 |
+
"risk_level": "MEDIUM",
|
| 223 |
+
},
|
| 224 |
+
"Contract by using": {
|
| 225 |
+
"risky_pattern": "Bound to contract by merely using the service (browsewrap)",
|
| 226 |
+
"safe_alternative": (
|
| 227 |
+
"By creating an account, the User acknowledges they have read, understood, and agree "
|
| 228 |
+
"to be bound by these Terms. The User must affirmatively accept these Terms via "
|
| 229 |
+
"checkbox or click-through before account creation. Continued use after material "
|
| 230 |
+
"changes requires re-acceptance."
|
| 231 |
+
),
|
| 232 |
+
"legal_basis": "Specht v. Netscape, 306 F.3d 17 (2d Cir. 2002)",
|
| 233 |
+
"consumer_standard": "Clickwrap > browsewrap; affirmative acceptance required",
|
| 234 |
+
"risk_level": "MEDIUM",
|
| 235 |
+
},
|
| 236 |
+
|
| 237 |
+
# ── Additional Common Clauses ──────────────────────────────────
|
| 238 |
+
"Auto-Renewal": {
|
| 239 |
+
"risky_pattern": "Auto-renews silently without notice",
|
| 240 |
+
"safe_alternative": (
|
| 241 |
+
"This Agreement shall automatically renew for successive [term] periods unless "
|
| 242 |
+
"either party provides written notice of non-renewal at least thirty (30) days "
|
| 243 |
+
"before the end of the then-current term. The Provider shall send a reminder "
|
| 244 |
+
"notice thirty (30) to sixty (60) days before renewal. The Consumer may cancel "
|
| 245 |
+
"within fifteen (15) days of renewal for a pro-rated refund."
|
| 246 |
+
),
|
| 247 |
+
"legal_basis": "California Auto-Renewal Law (ARL) Bus. & Prof. Code § 17600; FTC Negative Option Rule",
|
| 248 |
+
"consumer_standard": "Reminder notice required; easy cancellation; pro-rated refund",
|
| 249 |
+
"risk_level": "HIGH",
|
| 250 |
+
},
|
| 251 |
+
"Indemnification": {
|
| 252 |
+
"risky_pattern": "User indemnifies company for all claims without limit",
|
| 253 |
+
"safe_alternative": (
|
| 254 |
+
"Each party shall indemnify, defend, and hold harmless the other party from "
|
| 255 |
+
"third-party claims arising from: (a) the indemnifying party's breach of this "
|
| 256 |
+
"Agreement, (b) the indemnifying party's negligence or willful misconduct, or "
|
| 257 |
+
"(c) the indemnifying party's violation of applicable law. The User's indemnification "
|
| 258 |
+
"obligation is limited to claims arising from the User's own negligence or "
|
| 259 |
+
"intentional acts. The maximum indemnification obligation shall not exceed [amount]."
|
| 260 |
+
),
|
| 261 |
+
"legal_basis": "Restatement (Second) of Contracts § 345; UCC § 2-607",
|
| 262 |
+
"consumer_standard": "Mutual indemnification; limited to own acts; capped",
|
| 263 |
+
"risk_level": "HIGH",
|
| 264 |
+
},
|
| 265 |
+
"Confidentiality": {
|
| 266 |
+
"risky_pattern": "Overly broad confidentiality with no exceptions or time limit",
|
| 267 |
+
"safe_alternative": (
|
| 268 |
+
"Each party agrees to maintain the confidentiality of the other's Confidential "
|
| 269 |
+
"Information for a period of [3-5] years from disclosure. Confidential Information "
|
| 270 |
+
"excludes: (a) publicly available information, (b) independently developed "
|
| 271 |
+
"information, (c) information received from a third party without restriction, "
|
| 272 |
+
"(d) information required to be disclosed by law or court order (with prompt notice "
|
| 273 |
+
"to the disclosing party)."
|
| 274 |
+
),
|
| 275 |
+
"legal_basis": "Restatement (Third) of Unfair Competition § 39-45",
|
| 276 |
+
"consumer_standard": "Time-limited; standard exceptions; required disclosure carve-out",
|
| 277 |
+
"risk_level": "MEDIUM",
|
| 278 |
+
},
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
# Mapping from CUAD/unfair labels to our template keys
_LABEL_TO_TEMPLATE = {
    "Uncapped Liability": "Uncapped Liability",
    "Arbitration": "Arbitration",
    "IP Ownership Assignment": "IP Ownership Assignment",
    "Termination for Convenience": "Termination for Convenience",
    "Limitation of liability": "Limitation of liability",
    "Unilateral termination": "Unilateral termination",
    "Liquidated Damages": "Liquidated Damages",
    "Unilateral change": "Unilateral change",
    "Content removal": "Content removal",
    "Non-Compete": "Non-Compete",
    "Exclusivity": "Exclusivity",
    "Anti-Assignment": "Anti-Assignment",
    "Jurisdiction": "Jurisdiction",
    "Choice of law": "Choice of law",
    "Contract by using": "Contract by using",
    "Cap on Liability": "Limitation of liability",  # Similar enough
    "No-Solicit of Customers": "Non-Compete",  # Use non-compete template
    "No-Solicit of Employees": "Non-Compete",
    "Non-Disparagement": "Confidentiality",  # Similar restrictive clause
}


# ═══════════════════════════════════════════════════════════════════════
# TIER 2: RAG RETRIEVAL (find fairer precedent clauses)
# ═══════════════════════════════════════════════════════════════════════

def _find_similar_templates(clause_label, clause_text):
    """
    Find the most relevant safe alternative template(s) for a given clause.
    Returns list of matching templates.
    """
    matches = []

    # Direct label match
    template_key = _LABEL_TO_TEMPLATE.get(clause_label)
    if template_key and template_key in SAFE_ALTERNATIVES:
        matches.append((template_key, SAFE_ALTERNATIVES[template_key], 1.0))

    # Also do keyword matching for clauses that might not have exact label matches
    clause_lower = clause_text.lower()
    keyword_map = {
        "Uncapped Liability": ["unlimited liability", "uncapped", "no limit on liability"],
        "Arbitration": ["arbitration", "arbitrate", "waive right to court", "class action waiver"],
        "Termination for Convenience": ["terminate at any time", "terminate without cause", "terminate without notice"],
        "Limitation of liability": ["not liable", "limitation of liability", "in no event", "disclaim"],
        "Unilateral change": ["modify at any time", "sole discretion", "change terms", "without notice"],
        "Content removal": ["remove content", "delete content", "remove at sole discretion"],
        "Auto-Renewal": ["auto-renew", "automatically renew", "automatic renewal"],
        "Indemnification": ["indemnif", "hold harmless"],
    }

    for key, keywords in keyword_map.items():
        if key in SAFE_ALTERNATIVES:
            for kw in keywords:
                if kw in clause_lower:
                    # Avoid duplicates
                    if not any(m[0] == key for m in matches):
                        matches.append((key, SAFE_ALTERNATIVES[key], 0.7))
                    break

    return matches


# ═══════════════════════════════════════════════════════════════════════
# TIER 3: LLM REFINEMENT
# ═══════════════════════════════════════════════════════════════════════

_LLM_MODEL = "Qwen/Qwen2.5-7B-Instruct"

def _refine_with_llm(original_clause, template, clause_label):
    """
    Use LLM to adapt the template to the specific clause context.
    The LLM refines — it does NOT generate from scratch (anti-hallucination).
    """
    if not _HAS_INFERENCE:
        return None

    try:
        token = os.environ.get("HF_TOKEN", "")
        client = InferenceClient(
            provider="hf-inference",
            api_key=token if token else None,
        )

        prompt = f"""You are a legal contract redlining assistant. Your task is to adapt a safe clause template to fit the specific context of an original risky clause.

RULES:
1. You MUST use the provided template as your base — do NOT generate clauses from scratch.
2. Preserve the legal protections in the template.
3. Adapt specific details (parties, amounts, timeframes) from the original clause.
4. Keep the same legal standard cited in the template.
5. Output ONLY the refined clause text, nothing else.
6. The refined clause should be immediately usable in a contract.

ORIGINAL RISKY CLAUSE:
{original_clause[:500]}

CLAUSE TYPE: {clause_label}

SAFE TEMPLATE:
{template['safe_alternative']}

LEGAL BASIS: {template['legal_basis']}

Write the refined safer clause (adapt the template to this specific contract's context):"""

        response = client.chat_completion(
            model=_LLM_MODEL,
            messages=[
                {"role": "system", "content": "You are a legal contract redlining expert. Output ONLY the refined clause text."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=512,
            temperature=0.2,
        )
        refined = response.choices[0].message.content.strip()

        # Sanity check: refined should be substantial
        if len(refined) < 50:
            return None
        return refined

    except Exception as e:
        print(f"[ClauseGuard Redline] LLM refinement error: {e}")
        return None


# ═══════════════════════════════════════════════════════════════════════
# PUBLIC API
# ═══════════════════════════════════════════════════════════════════════

def generate_redlines(analysis_result, use_llm=True):
    """
    Generate redline suggestions for all flagged clauses in the analysis.

    Returns list of redline suggestions:
    [{
        "original_text": str,
        "clause_label": str,
        "risk_level": str,
        "safe_alternative": str,
        "legal_basis": str,
        "consumer_standard": str,
        "tier": "template" | "llm_refined",
        "confidence": str,
    }]
    """
    if analysis_result is None:
        return []

    clauses = analysis_result.get("clauses", [])
    if not clauses:
        return []

    redlines = []
    seen_labels = set()  # Deduplicate by label

    # Sort by risk level: CRITICAL first
    risk_order = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3}
    sorted_clauses = sorted(clauses, key=lambda c: risk_order.get(c.get("risk", "LOW"), 3))

    for clause in sorted_clauses:
        label = clause.get("label", "")
        risk = clause.get("risk", "LOW")
        text = clause.get("text", "")

        # Skip LOW risk and already-seen labels
        if risk == "LOW" or label in seen_labels:
            continue
        seen_labels.add(label)

        # Find matching templates (Tier 1 + Tier 2)
        matches = _find_similar_templates(label, text)
        if not matches:
            continue

        best_key, best_template, score = matches[0]

        # Tier 3: Try LLM refinement if enabled
        refined_text = None
        tier = "template"
        if use_llm and risk in ("CRITICAL", "HIGH"):
            refined_text = _refine_with_llm(text, best_template, label)
            if refined_text:
                tier = "llm_refined"

        redlines.append({
            "original_text": text[:500],
            "clause_label": label,
            "risk_level": risk,
            "safe_alternative": refined_text or best_template["safe_alternative"],
            "template_alternative": best_template["safe_alternative"],
            "legal_basis": best_template["legal_basis"],
            "consumer_standard": best_template["consumer_standard"],
            "tier": tier,
        })

    return redlines


def render_redlines_html(redlines):
    """Render redline suggestions as HTML for Gradio."""
    if not redlines:
        return '''<div style="padding:24px;text-align:center;color:#6b7280;font-family:system-ui,sans-serif;">
            <p style="font-size:16px;">📝 No redline suggestions available.</p>
            <p style="font-size:13px;">Analyze a contract first — redlining suggestions will appear for risky clauses.</p>
        </div>'''

    risk_styles = {
        "CRITICAL": ("#dc2626", "#fef2f2", "⚠️"),
        "HIGH": ("#ea580c", "#fff7ed", "⚡"),
        "MEDIUM": ("#ca8a04", "#fefce8", "📋"),
        "LOW": ("#16a34a", "#f0fdf4", "✓"),
    }

    html = '<div style="font-family:system-ui,sans-serif;">'

    # Summary header
    crit = sum(1 for r in redlines if r["risk_level"] == "CRITICAL")
    high = sum(1 for r in redlines if r["risk_level"] == "HIGH")
    med = sum(1 for r in redlines if r["risk_level"] == "MEDIUM")
    llm_count = sum(1 for r in redlines if r["tier"] == "llm_refined")

    html += f'''
    <div style="padding:16px;background:linear-gradient(135deg,#eff6ff,#f0fdf4);border-radius:12px;margin-bottom:16px;border:1px solid #e5e7eb;">
        <div style="display:flex;align-items:center;gap:8px;margin-bottom:8px;">
            <span style="font-size:24px;">✏️</span>
            <h2 style="margin:0;font-size:18px;color:#1f2937;">Clause Redlining Suggestions</h2>
        </div>
        <p style="font-size:13px;color:#6b7280;margin:0;">
            {len(redlines)} suggestions: {crit} Critical · {high} High · {med} Medium
            {f" · {llm_count} LLM-refined" if llm_count else ""}
        </p>
    </div>
    '''

    for i, redline in enumerate(redlines):
        border_color, bg_color, icon = risk_styles.get(
            redline["risk_level"], ("#6b7280", "#f9fafb", "•")
        )
        tier_badge = (
            '<span style="font-size:10px;background:#eff6ff;color:#3b82f6;padding:2px 8px;border-radius:4px;">🤖 LLM Refined</span>'
            if redline["tier"] == "llm_refined"
            else '<span style="font-size:10px;background:#f0fdf4;color:#16a34a;padding:2px 8px;border-radius:4px;">📋 Template</span>'
        )

        # Escape HTML special characters before embedding clause text in markup
        original_preview = redline["original_text"][:200].replace("<", "&lt;").replace(">", "&gt;")
        safe_text = redline["safe_alternative"].replace("<", "&lt;").replace(">", "&gt;")

        html += f'''
        <div style="border:1px solid #e5e7eb;border-left:4px solid {border_color};border-radius:8px;margin-bottom:12px;overflow:hidden;">
            <!-- Header -->
            <div style="padding:12px 16px;background:{bg_color};border-bottom:1px solid #e5e7eb;">
                <div style="display:flex;align-items:center;justify-content:space-between;">
                    <div style="display:flex;align-items:center;gap:8px;">
                        <span style="font-size:16px;">{icon}</span>
                        <span style="font-size:14px;font-weight:600;color:{border_color};">{redline["clause_label"]}</span>
                        <span style="font-size:11px;color:{border_color};text-transform:uppercase;font-weight:600;">{redline["risk_level"]}</span>
                    </div>
                    {tier_badge}
                </div>
            </div>

            <!-- Body -->
            <div style="padding:16px;">
                <!-- Original (risky) -->
                <div style="margin-bottom:12px;">
                    <div style="font-size:11px;font-weight:600;color:#991b1b;text-transform:uppercase;margin-bottom:4px;">❌ Original (Risky)</div>
                    <div style="background:#fef2f2;border:1px solid #fecaca;border-radius:6px;padding:10px;font-size:12px;color:#991b1b;line-height:1.6;">
                        <del>{original_preview}{"..." if len(redline["original_text"]) > 200 else ""}</del>
                    </div>
                </div>

                <!-- Suggested (safe) -->
                <div style="margin-bottom:12px;">
                    <div style="font-size:11px;font-weight:600;color:#166534;text-transform:uppercase;margin-bottom:4px;">✅ Suggested Alternative</div>
                    <div style="background:#f0fdf4;border:1px solid #bbf7d0;border-radius:6px;padding:10px;font-size:12px;color:#166534;line-height:1.6;">
                        {safe_text}
                    </div>
                </div>

                <!-- Legal basis -->
                <div style="display:flex;gap:12px;flex-wrap:wrap;">
                    <div style="flex:1;min-width:200px;">
                        <div style="font-size:10px;font-weight:600;color:#6b7280;text-transform:uppercase;margin-bottom:2px;">📚 Legal Basis</div>
                        <div style="font-size:11px;color:#4b5563;">{redline["legal_basis"]}</div>
                    </div>
                    <div style="flex:1;min-width:200px;">
                        <div style="font-size:10px;font-weight:600;color:#6b7280;text-transform:uppercase;margin-bottom:2px;">🛡️ Consumer Standard</div>
                        <div style="font-size:11px;color:#4b5563;">{redline["consumer_standard"]}</div>
                    </div>
                </div>
            </div>
        </div>
        '''

    # Disclaimer
    html += '''
    <div style="margin-top:16px;padding:12px;background:#fefce8;border:1px solid #fde68a;border-radius:8px;">
        <p style="font-size:11px;color:#92400e;margin:0;line-height:1.5;">
            <strong>⚠️ Disclaimer:</strong> These are AI-generated suggestions based on legal templates and consumer protection standards.
            They are NOT legal advice. The suggested alternatives are starting points that should be reviewed and customized by a
            qualified attorney before use in any contract. Legal requirements vary by jurisdiction.
        </p>
    </div>
    '''

    html += '</div>'
    return html
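The clause-selection step of `generate_redlines` (CRITICAL-first ordering, per-label deduplication, LOW-risk skipping) can be sketched standalone. This is a hypothetical miniature for illustration; the `select_clauses` helper and sample clauses are not part of the repo, and the real function additionally matches templates and may call the LLM refiner:

```python
# Minimal sketch of the selection logic generate_redlines applies
# (hypothetical standalone version; names are illustrative).
risk_order = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3}

def select_clauses(clauses):
    seen = set()
    out = []
    # Stable sort puts CRITICAL clauses first, preserving input order within a level
    for c in sorted(clauses, key=lambda c: risk_order.get(c.get("risk", "LOW"), 3)):
        label, risk = c.get("label", ""), c.get("risk", "LOW")
        if risk == "LOW" or label in seen:
            continue  # skip low-risk clauses and duplicate labels
        seen.add(label)
        out.append((label, risk))
    return out

clauses = [
    {"label": "Arbitration", "risk": "HIGH"},
    {"label": "Uncapped Liability", "risk": "CRITICAL"},
    {"label": "Arbitration", "risk": "HIGH"},      # duplicate label: skipped
    {"label": "Confidentiality", "risk": "LOW"},   # LOW risk: skipped
]
print(select_clauses(clauses))
# [('Uncapped Liability', 'CRITICAL'), ('Arbitration', 'HIGH')]
```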
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 gradio>=5.23.0
-transformers>=
+transformers>=4.45.0
 torch>=2.5.0
 numpy>=2.0.0
 pdfplumber>=0.11.0
@@ -7,3 +7,5 @@ python-docx>=1.1.0
 peft>=0.15.0
 accelerate>=1.2.0
 sentence-transformers>=3.0.0
+python-doctr[torch]>=0.9.0
+huggingface_hub>=0.25.0
web/.env.example CHANGED
@@ -18,3 +18,10 @@ RESEND_API_KEY=re_...
 # App
 NEXT_PUBLIC_SITE_URL=http://localhost:3000
 CLAUSEGUARD_API_URL=https://gaurv007-clauseguard-api.hf.space
+
+# HF Inference API (for chatbot + redlining LLM)
+HF_TOKEN=hf_...
+
+# Optional: SaulLM for explain endpoint
+SAULLM_ENDPOINT=
+HF_API_TOKEN=
web/app/api/analyze/route.ts CHANGED
@@ -1,4 +1,5 @@
 import { NextRequest, NextResponse } from "next/server";
+import { createClient } from "@/lib/supabase/server";
 
 const API_URL = process.env.CLAUSEGUARD_API_URL || "https://gaurv007-clauseguard-api.hf.space";
 
@@ -14,10 +15,19 @@ export async function POST(req: NextRequest) {
     );
   }
 
-  // Forward
+  // Forward auth token to backend
+  const headers: Record<string, string> = { "Content-Type": "application/json" };
+  try {
+    const supabase = await createClient();
+    const { data: { session } } = await supabase.auth.getSession();
+    if (session?.access_token) {
+      headers["Authorization"] = `Bearer ${session.access_token}`;
+    }
+  } catch {}
+
   const response = await fetch(`${API_URL}/api/analyze`, {
     method: "POST",
-    headers
+    headers,
     body: JSON.stringify({ text, source_url }),
   });
web/app/api/chat/route.ts ADDED
@@ -0,0 +1,37 @@
+import { NextRequest, NextResponse } from "next/server";
+
+const API_URL = process.env.CLAUSEGUARD_API_URL || "https://gaurv007-clauseguard-api.hf.space";
+
+export async function POST(req: NextRequest) {
+  try {
+    const body = await req.json();
+    const { message, session_id, history } = body;
+
+    if (!message || !session_id) {
+      return NextResponse.json(
+        { error: "message and session_id are required" },
+        { status: 400 }
+      );
+    }
+
+    const response = await fetch(`${API_URL}/api/chat`, {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({ message, session_id, history: history || [] }),
+    });
+
+    if (!response.ok) {
+      const err = await response.text().catch(() => "");
+      throw new Error(err || `Backend error: ${response.status}`);
+    }
+
+    const result = await response.json();
+    return NextResponse.json(result);
+  } catch (error: any) {
+    console.error("Chat error:", error.message);
+    return NextResponse.json(
+      { error: error.message || "Chat failed. Try again." },
+      { status: 500 }
+    );
+  }
+}
web/app/api/redline/route.ts ADDED
@@ -0,0 +1,37 @@
+import { NextRequest, NextResponse } from "next/server";
+
+const API_URL = process.env.CLAUSEGUARD_API_URL || "https://gaurv007-clauseguard-api.hf.space";
+
+export async function POST(req: NextRequest) {
+  try {
+    const body = await req.json();
+    const { session_id, text, use_llm } = body;
+
+    if (!session_id && !text) {
+      return NextResponse.json(
+        { error: "Provide session_id or text" },
+        { status: 400 }
+      );
+    }
+
+    const response = await fetch(`${API_URL}/api/redline`, {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({ session_id, text, use_llm: use_llm ?? true }),
+    });
+
+    if (!response.ok) {
+      const err = await response.text().catch(() => "");
+      throw new Error(err || `Backend error: ${response.status}`);
+    }
+
+    const result = await response.json();
+    return NextResponse.json(result);
+  } catch (error: any) {
+    console.error("Redline error:", error.message);
+    return NextResponse.json(
+      { error: error.message || "Redlining failed" },
+      { status: 500 }
+    );
+  }
+}
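As a sketch of what this proxy forwards, here is a hypothetical client-side payload builder (the helper name is illustrative, not part of the repo) mirroring the route's "session_id or text" validation and its `use_llm ?? true` default:

```python
import json

def redline_payload(session_id=None, text=None, use_llm=True):
    """Build the JSON body the /api/redline proxy expects (hypothetical helper)."""
    if not session_id and not text:
        # Same guard as the route: at least one input is required
        raise ValueError("Provide session_id or text")
    return json.dumps({"session_id": session_id, "text": text, "use_llm": use_llm})

print(redline_payload(session_id="abc123"))
# {"session_id": "abc123", "text": null, "use_llm": true}
```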
web/app/dashboard-pages/analyze/page.tsx CHANGED
@@ -9,7 +9,8 @@ import {
   AlertTriangle, Tag, BookOpen, ClipboardList, DollarSign,
   Calendar, Building, MapPin, Hash, Bot, FileSearch, Percent, Clock,
   User, BookMarked, ShieldX, HelpCircle, Cpu, PenTool, Zap,
-  ShieldOff, CircleSlash, MessageSquareWarning, Construction
+  ShieldOff, CircleSlash, MessageSquareWarning, Construction,
+  MessageSquare, Send, Loader2
 } from "lucide-react";
 
 interface Cat { name: string; severity: string; description?: string; confidence?: number; }
@@ -19,6 +20,17 @@ interface Contradiction { type: string; explanation: string; severity: string; c
 interface Obligation { type: string; party: string; description: string; deadline: string; priority?: number; }
 interface ComplianceCheck { requirement: string; description: string; severity: string; status: string; matched_keywords: string[]; context?: string[]; }
 interface ComplianceReg { description: string; compliance_rate: number; checks: ComplianceCheck[]; overall_status: string; negated_count?: number; ambiguous_count?: number; }
+interface Redline {
+  original_text: string;
+  clause_label: string;
+  risk_level: string;
+  safe_alternative: string;
+  template_alternative?: string;
+  legal_basis: string;
+  consumer_standard: string;
+  tier: string;
+}
+interface ChatMessage { role: "user" | "assistant"; content: string; }
 interface AnalysisResult {
   risk_score: number;
   grade: string;
@@ -29,8 +41,10 @@ interface AnalysisResult {
   contradictions: Contradiction[];
   obligations: Obligation[];
   compliance: Record<string, ComplianceReg>;
+  redlines: Redline[];
   model: string;
   latency_ms: number;
+  session_id?: string;
 }
 
 const SEV_CONFIG: Record<string, { icon: any; label: string; text: string; bg: string; border: string; ring: string }> = {
@@ -169,6 +183,9 @@ export default function AnalyzePage() {
   const [scanLimit, setScanLimit] = useState(10);
   const [canUpload, setCanUpload] = useState(false);
   const [showUpgrade, setShowUpgrade] = useState(false);
+  const [chatMessages, setChatMessages] = useState<ChatMessage[]>([]);
+  const [chatInput, setChatInput] = useState("");
+  const [chatLoading, setChatLoading] = useState(false);
   const fileInputRef = useRef<HTMLInputElement>(null);
 
   // Fetch user profile from DB on mount — no hardcoded emails or plans
@@ -237,6 +254,31 @@
     setCopied(true); setTimeout(() => setCopied(false), 2000);
   }
 
+  async function handleChat() {
+    if (!chatInput.trim() || !results?.session_id) return;
+    const userMsg: ChatMessage = { role: "user", content: chatInput.trim() };
+    setChatMessages(prev => [...prev, userMsg]);
+    setChatInput("");
+    setChatLoading(true);
+    try {
+      const res = await fetch("/api/chat", {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({
+          message: userMsg.content,
   const flagged = results?.results.filter(r => r.categories.length > 0) || [];
   const filtered = filter === "all" ? flagged : flagged.filter(r => r.categories.some(c => c.severity === filter));
   const sevCounts = { CRITICAL: 0, HIGH: 0, MEDIUM: 0, LOW: 0 };
@@ -260,6 +302,8 @@
     { key: "contradictions", label: "Issues", icon: AlertTriangle, count: results?.contradictions.length || 0 },
     { key: "obligations", label: "Obligations", icon: ClipboardList, count: results?.obligations.length || 0 },
     { key: "compliance", label: "Compliance", icon: ShieldCheck, count: Object.keys(results?.compliance || {}).length },
   ];
 
   return (
@@ -668,6 +712,139 @@
             })}
           </div>
         )}
       </div>
     </div>
   ) : (
|
| 269 |
+
session_id: results.session_id,
|
| 270 |
+
history: chatMessages.slice(-6),
|
| 271 |
+
}),
|
| 272 |
+
});
|
| 273 |
+
if (!res.ok) throw new Error((await res.json()).error || "Chat failed");
|
| 274 |
+
const data = await res.json();
|
| 275 |
+
setChatMessages(prev => [...prev, { role: "assistant", content: data.response }]);
|
| 276 |
+
} catch (e: any) {
|
| 277 |
+
setChatMessages(prev => [...prev, { role: "assistant", content: `⚠️ ${e.message}` }]);
|
| 278 |
+
}
|
| 279 |
+
setChatLoading(false);
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
const flagged = results?.results.filter(r => r.categories.length > 0) || [];
|
| 283 |
const filtered = filter === "all" ? flagged : flagged.filter(r => r.categories.some(c => c.severity === filter));
|
| 284 |
const sevCounts = { CRITICAL: 0, HIGH: 0, MEDIUM: 0, LOW: 0 };
|
|
|
|
| 302 |
{ key: "contradictions", label: "Issues", icon: AlertTriangle, count: results?.contradictions.length || 0 },
|
| 303 |
{ key: "obligations", label: "Obligations", icon: ClipboardList, count: results?.obligations.length || 0 },
|
| 304 |
{ key: "compliance", label: "Compliance", icon: ShieldCheck, count: Object.keys(results?.compliance || {}).length },
|
| 305 |
+
{ key: "redlining", label: "Redlining", icon: PenTool, count: results?.redlines?.length || 0 },
|
| 306 |
+
{ key: "chat", label: "Q&A", icon: MessageSquare, count: chatMessages.length },
|
| 307 |
];
|
| 308 |
|
| 309 |
return (
|
|
|
|
| 712 |
})}
|
| 713 |
</div>
|
| 714 |
)}
|
| 715 |
+
|
| 716 |
+
{/* Redlining */}
|
| 717 |
+
{activeTab === "redlining" && (
|
| 718 |
+
<div className="space-y-3">
|
| 719 |
+
{(!results.redlines || results.redlines.length === 0) ? (
|
| 720 |
+
<div className="border border-dashed border-zinc-200 rounded-xl p-8 sm:p-10 text-center bg-white">
|
| 721 |
+
<PenTool className="w-8 h-8 text-zinc-300 mx-auto mb-2" />
|
| 722 |
+
<p className="text-sm text-zinc-500">No redlining suggestions for this contract.</p>
|
| 723 |
+
</div>
|
| 724 |
+
) : (
|
| 725 |
+
<>
|
| 726 |
+
<div className="bg-gradient-to-r from-blue-50 to-emerald-50 rounded-xl p-4 border border-zinc-200 mb-2">
|
| 727 |
+
<div className="flex items-center gap-2 mb-1">
|
| 728 |
+
<PenTool className="w-4 h-4 text-zinc-600" />
|
| 729 |
+
<span className="text-sm font-semibold text-zinc-800">Clause Redlining Suggestions</span>
|
| 730 |
+
</div>
|
| 731 |
+
<p className="text-xs text-zinc-500">
|
| 732 |
+
{results.redlines.length} suggestions · {results.redlines.filter(r => r.tier === "llm_refined").length} LLM-refined
|
| 733 |
+
</p>
|
| 734 |
+
</div>
|
| 735 |
+
{results.redlines.map((rl, i) => {
|
| 736 |
+
const isHigh = rl.risk_level === "CRITICAL" || rl.risk_level === "HIGH";
|
| 737 |
+
const conf = SEV_CONFIG[rl.risk_level] || SEV_CONFIG.MEDIUM;
|
| 738 |
+
return (
|
| 739 |
+
<div key={i} className={`bg-white border rounded-xl overflow-hidden ${conf.border}`}>
|
| 740 |
+
<div className={`px-4 py-3 ${conf.bg} border-b ${conf.border} flex items-center justify-between`}>
|
| 741 |
+
<div className="flex items-center gap-2">
|
| 742 |
+
<conf.icon className={`w-4 h-4 ${conf.text}`} />
|
| 743 |
+
<span className={`text-sm font-semibold ${conf.text}`}>{rl.clause_label}</span>
|
| 744 |
+
<span className={`text-[10px] uppercase font-bold ${conf.text}`}>{rl.risk_level}</span>
|
| 745 |
+
</div>
|
| 746 |
+
<span className={`text-[10px] px-2 py-0.5 rounded border ${
|
| 747 |
+
rl.tier === "llm_refined"
|
| 748 |
+
? "bg-indigo-50 text-indigo-600 border-indigo-200"
|
| 749 |
+
: "bg-emerald-50 text-emerald-600 border-emerald-200"
|
| 750 |
+
}`}>
|
| 751 |
+
{rl.tier === "llm_refined" ? "🤖 LLM Refined" : "📋 Template"}
|
| 752 |
+
</span>
|
| 753 |
+
</div>
|
| 754 |
+
<div className="p-4 space-y-3">
|
| 755 |
+
<div>
|
| 756 |
+
<p className="text-[10px] font-semibold text-red-600 uppercase mb-1">❌ Original (Risky)</p>
|
| 757 |
+
<div className="bg-red-50 border border-red-100 rounded-lg p-3 text-xs text-red-800 leading-relaxed line-through">
|
| 758 |
+
{rl.original_text.slice(0, 200)}{rl.original_text.length > 200 ? "..." : ""}
|
| 759 |
+
</div>
|
| 760 |
+
</div>
|
| 761 |
+
<div>
|
| 762 |
+
<p className="text-[10px] font-semibold text-emerald-600 uppercase mb-1">✅ Suggested Alternative</p>
|
| 763 |
+
<div className="bg-emerald-50 border border-emerald-100 rounded-lg p-3 text-xs text-emerald-800 leading-relaxed">
|
| 764 |
+
{rl.safe_alternative}
|
| 765 |
+
</div>
|
| 766 |
+
</div>
|
| 767 |
+
<div className="flex gap-3 flex-wrap text-[10px] text-zinc-500">
|
| 768 |
+
<span>📚 {rl.legal_basis}</span>
|
| 769 |
+
<span>🛡️ {rl.consumer_standard}</span>
|
| 770 |
+
</div>
|
| 771 |
+
</div>
|
| 772 |
+
</div>
|
| 773 |
+
);
|
| 774 |
+
})}
|
| 775 |
+
<div className="bg-amber-50 border border-amber-200 rounded-lg p-3 text-[11px] text-amber-800">
|
| 776 |
+
<strong>⚠️ Disclaimer:</strong> These are AI-generated suggestions, NOT legal advice. Consult an attorney before use.
|
| 777 |
+
</div>
|
| 778 |
+
</>
|
| 779 |
+
)}
|
| 780 |
+
</div>
|
| 781 |
+
)}
|
| 782 |
+
|
| 783 |
+
{/* Chat */}
|
| 784 |
+
{activeTab === "chat" && (
|
| 785 |
+
<div className="flex flex-col h-[350px] sm:h-[420px]">
|
| 786 |
+
{!results.session_id ? (
|
| 787 |
+
<div className="flex-1 flex items-center justify-center">
|
| 788 |
+
<div className="text-center">
|
| 789 |
+
<MessageSquare className="w-8 h-8 text-zinc-300 mx-auto mb-2" />
|
| 790 |
+
<p className="text-sm text-zinc-500">Chat unavailable — session not initialized.</p>
|
| 791 |
+
<p className="text-xs text-zinc-400 mt-1">Try analyzing again with the backend running.</p>
|
| 792 |
+
</div>
|
| 793 |
+
</div>
|
| 794 |
+
) : (
|
| 795 |
+
<>
|
| 796 |
+
<div className="flex-1 overflow-y-auto space-y-3 pr-1 mb-3">
|
| 797 |
+
{chatMessages.length === 0 && (
|
| 798 |
+
<div className="text-center py-8">
|
| 799 |
+
<MessageSquare className="w-8 h-8 text-zinc-200 mx-auto mb-2" />
|
| 800 |
+
<p className="text-sm text-zinc-400">Ask a question about your contract</p>
|
| 801 |
+
<div className="mt-3 flex flex-wrap justify-center gap-2">
|
| 802 |
+
{["What are the main risks?", "Who are the parties?", "Is there an arbitration clause?", "Summarize key terms"].map(q => (
|
| 803 |
+
<button key={q} onClick={() => { setChatInput(q); }}
|
| 804 |
+
className="text-xs px-3 py-1.5 rounded-full border border-zinc-200 text-zinc-500 hover:bg-zinc-50 transition-colors">
|
| 805 |
+
{q}
|
| 806 |
+
</button>
|
| 807 |
+
))}
|
| 808 |
+
</div>
|
| 809 |
+
</div>
|
| 810 |
+
)}
|
| 811 |
+
{chatMessages.map((msg, i) => (
|
| 812 |
+
<div key={i} className={`flex ${msg.role === "user" ? "justify-end" : "justify-start"}`}>
|
| 813 |
+
<div className={`max-w-[85%] rounded-xl px-3.5 py-2.5 text-sm leading-relaxed ${
|
| 814 |
+
msg.role === "user"
|
| 815 |
+
? "bg-zinc-900 text-white"
|
| 816 |
+
: "bg-zinc-100 text-zinc-700 border border-zinc-200"
|
| 817 |
+
}`}>
|
| 818 |
+
{msg.content}
|
| 819 |
+
</div>
|
| 820 |
+
</div>
|
| 821 |
+
))}
|
| 822 |
+
{chatLoading && (
|
| 823 |
+
<div className="flex justify-start">
|
| 824 |
+
<div className="bg-zinc-100 border border-zinc-200 rounded-xl px-4 py-3">
|
| 825 |
+
<Loader2 className="w-4 h-4 text-zinc-400 animate-spin" />
|
| 826 |
+
</div>
|
| 827 |
+
</div>
|
| 828 |
+
)}
|
| 829 |
+
</div>
|
| 830 |
+
<div className="flex gap-2 border-t border-zinc-100 pt-3">
|
| 831 |
+
<input
|
| 832 |
+
value={chatInput}
|
| 833 |
+
onChange={(e) => setChatInput(e.target.value)}
|
| 834 |
+
onKeyDown={(e) => e.key === "Enter" && !e.shiftKey && handleChat()}
|
| 835 |
+
placeholder="Ask about your contract..."
|
| 836 |
+
className="flex-1 px-3 py-2 border border-zinc-200 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-zinc-900/10"
|
| 837 |
+
disabled={chatLoading}
|
| 838 |
+
/>
|
| 839 |
+
<button onClick={handleChat} disabled={chatLoading || !chatInput.trim()}
|
| 840 |
+
className="px-3 py-2 bg-zinc-900 text-white rounded-lg hover:bg-zinc-800 disabled:opacity-40 transition-colors">
|
| 841 |
+
<Send className="w-4 h-4" />
|
| 842 |
+
</button>
|
| 843 |
+
</div>
|
| 844 |
+
</>
|
| 845 |
+
)}
|
| 846 |
+
</div>
|
| 847 |
+
)}
|
| 848 |
</div>
|
| 849 |
</div>
|
| 850 |
) : (
|
web/app/page.tsx
CHANGED
|
@@ -3,7 +3,8 @@ import {
|
|
| 3 |
ShieldCheck, ShieldAlert, Scale, Gavel, ScanText, FileCheck,
|
| 4 |
TriangleAlert, ArrowRight, Zap, Eye, Download, ChevronRight,
|
| 5 |
Sparkles, Lock, Globe, Ban, FileX, Stamp, Layers, Tag, AlertTriangle,
|
| 6 |
-
ClipboardList, Landmark, Building, BookOpen, CheckCircle, Cpu
|
|
|
|
| 7 |
} from "lucide-react";
|
| 8 |
|
| 9 |
const CLAUSES = [
|
|
@@ -21,22 +22,26 @@ const CLAUSES = [
|
|
| 21 |
{ icon: ClipboardList, name: "Obligations", desc: "Track monetary, compliance, reporting tasks with priority", severity: "medium" },
|
| 22 |
{ icon: Landmark, name: "Compliance", desc: "GDPR, CCPA, SOX, HIPAA, FINRA with negation detection", severity: "high" },
|
| 23 |
{ icon: BookOpen, name: "Compare Contracts", desc: "Semantic similarity with sentence embeddings", severity: "low" },
|
|
| 24 |
];
|
| 25 |
|
| 26 |
const STEPS = [
|
| 27 |
-
{ icon: Download, title: "Upload or paste", desc: "Drop a PDF, DOCX, or paste contract text directly." },
|
| 28 |
-
{ icon: ScanText, title: "
|
| 29 |
-
{ icon: TriangleAlert, title: "Get precise insights", desc: "Risk score, contradictions, obligations,
|
| 30 |
];
|
| 31 |
|
| 32 |
const PRICING = [
|
| 33 |
{
|
| 34 |
name: "Free", price: "0", period: "", highlight: false, cta: "Get started",
|
| 35 |
-
features: ["10 scans per month", "41 clause categories", "Risk scoring", "ML Legal NER", "NLI contradiction detection", "Compliance with negation detection"],
|
| 36 |
},
|
| 37 |
{
|
| 38 |
name: "Pro", price: "999", period: "/mo", highlight: true, cta: "Start free trial",
|
| 39 |
-
features: ["Unlimited scans", "Upload PDF/DOCX files", "Contract comparison", "AI clause explanations", "Scan history", "PDF report export", "Obligation tracker with priority", "Priority support"],
|
| 40 |
},
|
| 41 |
{
|
| 42 |
name: "Team", price: "3,999", period: "/mo", highlight: false, cta: "Talk to us",
|
|
@@ -59,14 +64,14 @@ export default function Home() {
|
|
| 59 |
<div className="max-w-2xl">
|
| 60 |
<div className="inline-flex items-center gap-2 px-3 py-1 rounded-full border border-zinc-200 text-[13px] text-zinc-500 mb-6">
|
| 61 |
<Sparkles className="w-3.5 h-3.5 text-zinc-400" />
|
| 62 |
-
|
| 63 |
</div>
|
| 64 |
<h1 className="text-3xl sm:text-[42px] lg:text-5xl font-semibold tracking-tight leading-[1.1]">
|
| 65 |
Know what you are<br className="hidden sm:block" /> agreeing to
|
| 66 |
</h1>
|
| 67 |
<p className="mt-5 text-base sm:text-[17px] text-zinc-500 leading-relaxed max-w-lg">
|
| 68 |
-
ClauseGuard scans contracts
|
| 69 |
-
|
| 70 |
</p>
|
| 71 |
<div className="mt-8 flex flex-col sm:flex-row gap-3">
|
| 72 |
<Link href="/dashboard-pages/analyze" className="inline-flex items-center justify-center gap-2 bg-zinc-900 text-white px-5 py-2.5 rounded-lg text-sm font-medium hover:bg-zinc-800 transition-colors">
|
|
@@ -87,11 +92,11 @@ export default function Home() {
|
|
| 87 |
<ShieldCheck className="w-4 h-4 text-zinc-400" />
|
| 88 |
<p className="text-[13px] font-medium text-zinc-400 uppercase tracking-wider">Detection</p>
|
| 89 |
</div>
|
| 90 |
-
<h2 className="text-xl sm:text-2xl font-semibold tracking-tight">
|
| 91 |
<p className="mt-2 text-zinc-500 text-sm sm:text-[15px] max-w-lg">
|
| 92 |
-
Based on the CUAD taxonomy + CLAUDETTE framework
|
| 93 |
</p>
|
| 94 |
-
<div className="mt-8 sm:mt-10 grid grid-cols-2 sm:grid-cols-
|
| 95 |
{CLAUSES.map((c) => (
|
| 96 |
<div key={c.name} className="group border border-zinc-100 rounded-xl p-3 sm:p-4 hover:border-zinc-200 hover:shadow-sm transition-all cursor-default">
|
| 97 |
<div className={`w-7 h-7 sm:w-8 sm:h-8 rounded-lg flex items-center justify-center border ${sevColor[c.severity]}`}>
|
|
@@ -135,15 +140,15 @@ export default function Home() {
|
|
| 135 |
<Cpu className="w-4 h-4 text-zinc-400" />
|
| 136 |
<p className="text-[13px] font-medium text-zinc-400 uppercase tracking-wider">Technology</p>
|
| 137 |
</div>
|
| 138 |
-
<h2 className="text-xl sm:text-2xl font-semibold tracking-tight">Built on
|
| 139 |
<div className="mt-8 grid sm:grid-cols-2 lg:grid-cols-3 gap-3 sm:gap-4">
|
| 140 |
{[
|
| 141 |
-
{ name: "Legal-BERT Classifier", icon: Cpu, desc: "LoRA fine-tuned on 41 CUAD categories with sigmoid multi-label classification
|
| 142 |
-
{ name: "Legal-BERT NER", icon: Tag, desc: "
|
| 143 |
-
{ name: "DeBERTa-v3 NLI", icon: AlertTriangle, desc: "
|
| 144 |
-
{ name: "
|
| 145 |
-
{ name: "
|
| 146 |
-
{ name: "
|
| 147 |
].map((m) => (
|
| 148 |
<div key={m.name} className="border border-zinc-100 rounded-xl p-4 hover:border-zinc-200 hover:shadow-sm transition-all">
|
| 149 |
<div className="flex items-center gap-2 mb-2">
|
|
@@ -211,7 +216,7 @@ export default function Home() {
|
|
| 211 |
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-8 flex flex-col sm:flex-row justify-between items-center gap-4">
|
| 212 |
<div className="flex items-center gap-2">
|
| 213 |
<ShieldCheck className="w-4 h-4 text-zinc-300" />
|
| 214 |
-
<span className="text-[13px] text-zinc-400">ClauseGuard
|
| 215 |
</div>
|
| 216 |
<div className="flex gap-5 text-[13px] text-zinc-400">
|
| 217 |
<Link href="/privacy" className="hover:text-zinc-600">Privacy</Link>
|
|
|
|
| 3 |
ShieldCheck, ShieldAlert, Scale, Gavel, ScanText, FileCheck,
|
| 4 |
TriangleAlert, ArrowRight, Zap, Eye, Download, ChevronRight,
|
| 5 |
Sparkles, Lock, Globe, Ban, FileX, Stamp, Layers, Tag, AlertTriangle,
|
| 6 |
+
ClipboardList, Landmark, Building, BookOpen, CheckCircle, Cpu,
|
| 7 |
+
MessageSquare, PenTool, ScanLine
|
| 8 |
} from "lucide-react";
|
| 9 |
|
| 10 |
const CLAUSES = [
|
|
|
|
| 22 |
{ icon: ClipboardList, name: "Obligations", desc: "Track monetary, compliance, reporting tasks with priority", severity: "medium" },
|
| 23 |
{ icon: Landmark, name: "Compliance", desc: "GDPR, CCPA, SOX, HIPAA, FINRA with negation detection", severity: "high" },
|
| 24 |
{ icon: BookOpen, name: "Compare Contracts", desc: "Semantic similarity with sentence embeddings", severity: "low" },
|
| 25 |
+
{ icon: PenTool, name: "Clause Redlining", desc: "AI suggests safer alternatives with legal citations", severity: "critical" },
|
| 26 |
+
{ icon: MessageSquare, name: "Q&A Chatbot", desc: "Ask questions about your contract — RAG-powered answers", severity: "medium" },
|
| 27 |
+
{ icon: ScanLine, name: "OCR for Scanned PDFs", desc: "docTR engine auto-detects and OCRs scanned contracts", severity: "low" },
|
| 28 |
+
{ icon: Cpu, name: "6 AI Models", desc: "Legal-BERT, NER, NLI, Embeddings, OCR, Qwen2.5-7B LLM", severity: "low" },
|
| 29 |
];
|
| 30 |
|
| 31 |
const STEPS = [
|
| 32 |
+
{ icon: Download, title: "Upload or paste", desc: "Drop a PDF (even scanned!), DOCX, or paste contract text directly." },
|
| 33 |
+
{ icon: ScanText, title: "6 AI models analyze", desc: "Legal-BERT + NER + NLI + OCR + Embeddings + LLM scan your contract." },
|
| 34 |
+
{ icon: TriangleAlert, title: "Get precise insights", desc: "Risk score, redlining, Q&A chatbot, contradictions, obligations, and compliance." },
|
| 35 |
];
|
| 36 |
|
| 37 |
const PRICING = [
|
| 38 |
{
|
| 39 |
name: "Free", price: "0", period: "", highlight: false, cta: "Get started",
|
| 40 |
+
features: ["10 scans per month", "41 clause categories", "Risk scoring", "ML Legal NER", "NLI contradiction detection", "Compliance with negation detection", "Clause redlining suggestions", "OCR for scanned PDFs"],
|
| 41 |
},
|
| 42 |
{
|
| 43 |
name: "Pro", price: "999", period: "/mo", highlight: true, cta: "Start free trial",
|
| 44 |
+
features: ["Unlimited scans", "Upload PDF/DOCX files", "Contract comparison", "Q&A Chatbot (RAG)", "AI clause explanations", "LLM-refined redlining", "Scan history", "PDF report export", "Obligation tracker with priority", "Priority support"],
|
| 45 |
},
|
| 46 |
{
|
| 47 |
name: "Team", price: "3,999", period: "/mo", highlight: false, cta: "Talk to us",
|
|
|
|
| 64 |
<div className="max-w-2xl">
|
| 65 |
<div className="inline-flex items-center gap-2 px-3 py-1 rounded-full border border-zinc-200 text-[13px] text-zinc-500 mb-6">
|
| 66 |
<Sparkles className="w-3.5 h-3.5 text-zinc-400" />
|
| 67 |
+
6 AI models · 41 clause categories · RAG chatbot · clause redlining · OCR
|
| 68 |
</div>
|
| 69 |
<h1 className="text-3xl sm:text-[42px] lg:text-5xl font-semibold tracking-tight leading-[1.1]">
|
| 70 |
Know what you are<br className="hidden sm:block" /> agreeing to
|
| 71 |
</h1>
|
| 72 |
<p className="mt-5 text-base sm:text-[17px] text-zinc-500 leading-relaxed max-w-lg">
|
| 73 |
+
ClauseGuard scans contracts using 6 AI models. Get clause detection, risk scoring,
|
| 74 |
+
safer alternatives, Q&A chatbot, OCR for scanned PDFs, and compliance checks.
|
| 75 |
</p>
|
| 76 |
<div className="mt-8 flex flex-col sm:flex-row gap-3">
|
| 77 |
<Link href="/dashboard-pages/analyze" className="inline-flex items-center justify-center gap-2 bg-zinc-900 text-white px-5 py-2.5 rounded-lg text-sm font-medium hover:bg-zinc-800 transition-colors">
|
|
|
|
| 92 |
<ShieldCheck className="w-4 h-4 text-zinc-400" />
|
| 93 |
<p className="text-[13px] font-medium text-zinc-400 uppercase tracking-wider">Detection</p>
|
| 94 |
</div>
|
| 95 |
+
<h2 className="text-xl sm:text-2xl font-semibold tracking-tight">18 powerful analysis features</h2>
|
| 96 |
<p className="mt-2 text-zinc-500 text-sm sm:text-[15px] max-w-lg">
|
| 97 |
+
Based on the CUAD taxonomy + CLAUDETTE framework. Now with RAG chatbot, clause redlining, and OCR.
|
| 98 |
</p>
|
| 99 |
+
<div className="mt-8 sm:mt-10 grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-4 gap-2 sm:gap-3">
|
| 100 |
{CLAUSES.map((c) => (
|
| 101 |
<div key={c.name} className="group border border-zinc-100 rounded-xl p-3 sm:p-4 hover:border-zinc-200 hover:shadow-sm transition-all cursor-default">
|
| 102 |
<div className={`w-7 h-7 sm:w-8 sm:h-8 rounded-lg flex items-center justify-center border ${sevColor[c.severity]}`}>
|
|
|
|
| 140 |
<Cpu className="w-4 h-4 text-zinc-400" />
|
| 141 |
<p className="text-[13px] font-medium text-zinc-400 uppercase tracking-wider">Technology</p>
|
| 142 |
</div>
|
| 143 |
+
<h2 className="text-xl sm:text-2xl font-semibold tracking-tight">Built on 6 production AI models</h2>
|
| 144 |
<div className="mt-8 grid sm:grid-cols-2 lg:grid-cols-3 gap-3 sm:gap-4">
|
| 145 |
{[
|
| 146 |
+
{ name: "Legal-BERT Classifier", icon: Cpu, desc: "LoRA fine-tuned on 41 CUAD categories with sigmoid multi-label classification", source: "Mokshith31/legalbert-contract-clause-classification" },
|
| 147 |
+
{ name: "Legal-BERT NER", icon: Tag, desc: "Named entity recognition for parties, dates, money, jurisdictions", source: "matterstack/legal-bert-ner" },
|
| 148 |
+
{ name: "DeBERTa-v3 NLI", icon: AlertTriangle, desc: "Semantic contradiction detection between clause pairs", source: "cross-encoder/nli-deberta-v3-base" },
|
| 149 |
+
{ name: "RAG Chatbot", icon: MessageSquare, desc: "Embedding retrieval + Qwen2.5-7B LLM for contract Q&A", source: "all-MiniLM-L6-v2 + Qwen/Qwen2.5-7B-Instruct" },
|
| 150 |
+
{ name: "Clause Redlining", icon: PenTool, desc: "18+ legal templates + LLM refinement for safer clause alternatives", source: "FTC/EU/CFPB standards + Qwen2.5-7B" },
|
| 151 |
+
{ name: "docTR OCR", icon: ScanLine, desc: "Smart PDF router: auto-detects scanned PDFs and extracts text", source: "docTR fast_base + crnn_vgg16_bn" },
|
| 152 |
].map((m) => (
|
| 153 |
<div key={m.name} className="border border-zinc-100 rounded-xl p-4 hover:border-zinc-200 hover:shadow-sm transition-all">
|
| 154 |
<div className="flex items-center gap-2 mb-2">
|
|
|
|
| 216 |
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-8 flex flex-col sm:flex-row justify-between items-center gap-4">
|
| 217 |
<div className="flex items-center gap-2">
|
| 218 |
<ShieldCheck className="w-4 h-4 text-zinc-300" />
|
| 219 |
+
<span className="text-[13px] text-zinc-400">ClauseGuard v4.0 — not legal advice</span>
|
| 220 |
</div>
|
| 221 |
<div className="flex gap-5 text-[13px] text-zinc-400">
|
| 222 |
<Link href="/privacy" className="hover:text-zinc-600">Privacy</Link>
|
web/components/nav.tsx
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
import Link from "next/link";
|
| 4 |
import { usePathname } from "next/navigation";
|
| 5 |
-
import { ShieldCheck, Menu, X, Crown, GitCompare } from "lucide-react";
|
| 6 |
import { useState, useEffect } from "react";
|
| 7 |
import { createClient } from "@/lib/supabase/client";
|
| 8 |
|
|
@@ -27,7 +27,6 @@ export function Nav() {
|
|
| 27 |
const user = data.user;
|
| 28 |
setUserEmail(user?.email || null);
|
| 29 |
if (user) {
|
| 30 |
-
// Fetch role from database — no hardcoded emails
|
| 31 |
const { data: profile } = await supabase
|
| 32 |
.from("profiles")
|
| 33 |
.select("role")
|
|
@@ -44,7 +43,7 @@ export function Nav() {
|
|
| 44 |
<Link href="/" className="flex items-center gap-2">
|
| 45 |
<ShieldCheck className="w-5 h-5 text-zinc-900" strokeWidth={2.2} />
|
| 46 |
<span className="font-semibold text-[15px] tracking-tight text-zinc-900">ClauseGuard</span>
|
| 47 |
-
<span className="hidden sm:inline text-[10px] font-medium text-zinc-400 ml-1 border border-zinc-200 px-1.5 py-0.5 rounded">
|
| 48 |
</Link>
|
| 49 |
|
| 50 |
<div className="hidden md:flex items-center gap-1">
|
|
|
|
| 2 |
|
| 3 |
import Link from "next/link";
|
| 4 |
import { usePathname } from "next/navigation";
|
| 5 |
+
import { ShieldCheck, Menu, X, Crown, GitCompare, MessageSquare } from "lucide-react";
|
| 6 |
import { useState, useEffect } from "react";
|
| 7 |
import { createClient } from "@/lib/supabase/client";
|
| 8 |
|
|
|
|
| 27 |
const user = data.user;
|
| 28 |
setUserEmail(user?.email || null);
|
| 29 |
if (user) {
|
|
|
|
| 30 |
const { data: profile } = await supabase
|
| 31 |
.from("profiles")
|
| 32 |
.select("role")
|
|
|
|
| 43 |
<Link href="/" className="flex items-center gap-2">
|
| 44 |
<ShieldCheck className="w-5 h-5 text-zinc-900" strokeWidth={2.2} />
|
| 45 |
<span className="font-semibold text-[15px] tracking-tight text-zinc-900">ClauseGuard</span>
|
| 46 |
+
<span className="hidden sm:inline text-[10px] font-medium text-zinc-400 ml-1 border border-zinc-200 px-1.5 py-0.5 rounded">v4.0</span>
|
| 47 |
</Link>
|
| 48 |
|
| 49 |
<div className="hidden md:flex items-center gap-1">
|