anky2002 committed
Commit 2ca6bab
2 Parent(s): 970b3d549d0c4a

Merge branch 'main' of https://huggingface.co/spaces/gaurv007/ClauseGuard
README.md CHANGED
@@ -10,9 +10,17 @@ app_file: app.py
 pinned: false
 ---

- # 🛡️ ClauseGuard — World's Best Open-Source Legal Contract Analysis
-
- **ClauseGuard** is the most comprehensive open-source AI-powered legal contract analysis tool. It analyzes contracts using state-of-the-art legal NLP models and provides actionable risk assessments.

 ## ✨ Core Features

@@ -26,9 +34,12 @@ pinned: false
 | **Obligation Tracker** | Categorizes action items: monetary 💰, compliance ⚖️, reporting 📊, delivery 📦, termination 🛑 |
 | **Compliance Checker** | Validates against GDPR, CCPA, SOX, HIPAA, and FINRA requirements |
 | **Contract Comparison** | Side-by-side diff between two contracts with alignment scoring |

 ### Document Support
- - **PDF** parsing via `pdfplumber`
 - **DOCX/DOC** parsing via `python-docx`
 - **TXT / Markdown** direct text input

@@ -36,6 +47,8 @@ pinned: false
 - **3-Panel Professional Layout** — Upload sidebar + Main analysis + Summary dashboard
 - **Document Viewer** — Inline entity highlights (colored annotations)
 - **Clause Cards** — Expandable risk-badged cards with confidence scores
 - **Export Reports** — JSON (structured) and CSV (tabular) downloads
 - **Color-Coded Risk Badges** — Instant visual triage

@@ -44,12 +57,61 @@ pinned: false
 | Component | Technology |
 |-----------|------------|
 | Clause Classification | `Mokshith31/legalbert-contract-clause-classification` — LoRA adapter on `nlpaueb/legal-bert-base-uncased`, fine-tuned on CUAD 41-class taxonomy |
- | NER | Rule-based with 7 entity types (dates, money, parties, jurisdictions, defined terms) |
- | NLI | Heuristic contradiction detection with 5 conflict patterns + missing-clause detection |
 | Compliance | Regulatory keyword matching across GDPR, CCPA, SOX, HIPAA, FINRA |
- | Comparison | SequenceMatcher-based clause alignment with risk delta analysis |
 | Obligations | Regex pattern matching across 5 obligation categories |

 ## 📊 Risk Scoring Methodology

 Risk scores combine clause detection with weighted severity:
@@ -65,16 +127,10 @@ Final score normalized to 0-100 with letter grades:
 - D (50-69): High risk
 - F (70+): Critical risk

- ## 📚 Datasets & Research
-
- - [CUAD](https://huggingface.co/datasets/theatticusproject/cuad-qa) — 510 contracts, 13K annotations, 41 clause categories
- - [LegalBench](https://huggingface.co/datasets/nguha/legalbench) — 322 legal reasoning tasks
- - [LexGLUE](https://huggingface.co/datasets/coastalcph/lex_glue) — Unfair Terms of Service classification
- - Paper: [CUAD: An Expert-Annotated NLP Dataset for Legal Contract Review](https://arxiv.org/abs/2103.06268) (Hendrycks et al., 2021)
-
 ## 🚀 Usage

 1. **Upload** a contract (PDF, DOCX, or TXT) or paste text directly
 2. Click **Analyze Contract**
 3. View results across tabs:
    - **Document**: Full text with inline entity highlights
@@ -83,7 +139,9 @@ Final score normalized to 0-100 with letter grades:
    - **Contradictions**: Conflicting clauses and missing provisions
    - **Obligations**: Action items categorized by type
    - **Compliance**: Regulatory framework checks
 4. **Export** JSON/CSV reports

 ## 🔀 Compare Contracts

@@ -91,7 +149,6 @@ Switch to the **Compare Contracts** tab to:
 - Upload or paste two contracts side-by-side
 - See clause-level diffs (added, removed, modified)
 - Get an alignment score and risk delta
- - View raw JSON comparison data

 ## ⚠️ Disclaimer

@@ -103,6 +160,8 @@ Switch to the **Compare Contracts** tab to:
 - [Clause Classifier Model](https://huggingface.co/Mokshith31/legalbert-contract-clause-classification)
 - [Legal-BERT Base](https://huggingface.co/nlpaueb/legal-bert-base-uncased)
 - [CUAD Dataset](https://huggingface.co/datasets/theatticusproject/cuad-qa)
 - [CUAD Paper (arXiv:2103.06268)](https://arxiv.org/abs/2103.06268)

 ---
 
 pinned: false
 ---

+ # 🛡️ ClauseGuard v4.0 — World's Best Open-Source Legal Contract Analysis
+
+ **ClauseGuard** is the most comprehensive open-source AI-powered legal contract analysis tool. It analyzes contracts using state-of-the-art legal NLP models and provides actionable risk assessments, a Q&A chatbot, clause redlining, and OCR for scanned PDFs.
+
+ ## 🆕 What's New in v4.0
+
+ | Feature | Description |
+ |---------|-------------|
+ | **🔍 OCR for Scanned PDFs** | Smart PDF router: auto-detects native vs. scanned PDFs. Scanned PDFs are processed via the docTR OCR engine (CPU-friendly, ~150MB models) |
+ | **💬 Contract Q&A Chatbot** | RAG-powered chatbot that answers questions about your analyzed contract. Uses sentence-transformers for retrieval + Qwen2.5-7B via the HF Inference API for generation |
+ | **✏️ Clause Redlining** | 3-tier system: (1) template lookup from 18+ legal templates based on FTC/EU standards, (2) keyword-based matching, (3) LLM refinement for CRITICAL/HIGH-risk clauses |

 ## ✨ Core Features

 | **Obligation Tracker** | Categorizes action items: monetary 💰, compliance ⚖️, reporting 📊, delivery 📦, termination 🛑 |
 | **Compliance Checker** | Validates against GDPR, CCPA, SOX, HIPAA, and FINRA requirements |
 | **Contract Comparison** | Side-by-side diff between two contracts with alignment scoring |
+ | **Clause Redlining** | Suggests safer alternatives for risky clauses with legal citations |
+ | **Q&A Chatbot** | Ask questions about your contract using RAG (Retrieval-Augmented Generation) |
+ | **OCR Support** | Process scanned PDFs with the docTR OCR engine |

 ### Document Support
+ - **PDF** parsing via `pdfplumber` (native) + `docTR` OCR (scanned)
 - **DOCX/DOC** parsing via `python-docx`
 - **TXT / Markdown** direct text input

 - **3-Panel Professional Layout** — Upload sidebar + Main analysis + Summary dashboard
 - **Document Viewer** — Inline entity highlights (colored annotations)
 - **Clause Cards** — Expandable risk-badged cards with confidence scores
+ - **Redlining Tab** — Side-by-side original vs. suggested safer alternatives
+ - **Q&A Chat Tab** — Conversational interface to ask questions about the contract
 - **Export Reports** — JSON (structured) and CSV (tabular) downloads
 - **Color-Coded Risk Badges** — Instant visual triage

 | Component | Technology |
 |-----------|------------|
 | Clause Classification | `Mokshith31/legalbert-contract-clause-classification` — LoRA adapter on `nlpaueb/legal-bert-base-uncased`, fine-tuned on CUAD 41-class taxonomy |
+ | Legal NER | `matterstack/legal-bert-ner` (ML) with regex fallback for 7 entity types |
+ | NLI | `cross-encoder/nli-deberta-v3-base` (semantic contradiction detection) |
+ | Embeddings | `sentence-transformers/all-MiniLM-L6-v2` (384-dim, RAG retrieval) |
+ | LLM | `Qwen/Qwen2.5-7B-Instruct` via HF Inference API (chatbot + redlining) |
+ | OCR | `docTR` (fast_base + crnn_vgg16_bn) for scanned-PDF text extraction |
 | Compliance | Regulatory keyword matching across GDPR, CCPA, SOX, HIPAA, FINRA |
+ | Comparison | Semantic similarity with sentence embeddings + string-matching fallback |
 | Obligations | Regex pattern matching across 5 obligation categories |

+ ## 🔍 OCR Architecture (Smart PDF Router)
+
+ ```
+                PDF uploaded
+                     ↓
+ [detect_if_scanned] — pdfplumber extracts >50 chars/page?
+       ↓ yes                  ↓ no
+   Native PDF             Scanned PDF
+       ↓                      ↓
+   pdfplumber            docTR OCR (CPU)
+       ↓                      ↓
+   Contract text → existing analysis pipeline
+ ```
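The router heuristic in the diagram can be sketched in a few lines. This is an illustrative assumption, not the project's actual `ocr_engine.py`: `detect_if_scanned` here takes per-page text (e.g. from pdfplumber's `extract_text`) and applies the ~50-chars-per-page threshold mentioned above.

```python
# Sketch of the "smart PDF router" heuristic (assumed, not the project's exact code):
# a PDF counts as scanned when text extraction recovers too few characters per page.

SCANNED_THRESHOLD = 50  # avg chars per page below which we assume a scanned PDF

def detect_if_scanned(page_texts: list[str]) -> bool:
    """page_texts: text extracted per page (e.g. via pdfplumber's extract_text)."""
    if not page_texts:
        return True  # nothing extractable at all -> treat as scanned
    avg_chars = sum(len(t or "") for t in page_texts) / len(page_texts)
    return avg_chars < SCANNED_THRESHOLD

def route_pdf(page_texts: list[str]) -> str:
    # Native PDFs keep the pdfplumber text; scanned ones go to the OCR engine.
    return "doctr_ocr" if detect_if_scanned(page_texts) else "pdfplumber"
```

Either branch feeds plain contract text into the same downstream analysis pipeline, so the rest of the app never needs to know which extractor ran.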

+ ## 💬 Q&A Chatbot Architecture (RAG)
+
+ ```
+ User asks question about their contract
+        ↓
+ [1] Embed question with all-MiniLM-L6-v2
+        ↓
+ [2] Retrieve top-5 most relevant chunks from contract
+        ↓
+ [3] Build prompt:
+     - System: ClauseGuard analysis results (clauses, entities, risk scores)
+     - Context: retrieved contract chunks (≤2.5K tokens)
+     - User question
+        ↓
+ [4] Stream response from Qwen2.5-7B via HF Inference API
+ ```
+
+ **Key design:** Analyzed data (clauses, entities, risk scores) goes in the system prompt — NOT through RAG retrieval. Only the raw contract text goes through RAG. This gives the model both structured analysis AND verbatim evidence.
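Step [2] above is plain cosine-similarity retrieval. A minimal sketch, assuming `question_vec` and `chunk_vecs` are embeddings already produced by a model such as all-MiniLM-L6-v2 (the function names here are illustrative, not the project's `chatbot.py` API):

```python
import math

def cosine(a, b):
    # Cosine similarity between two equal-length vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def retrieve_top_k(question_vec, chunk_vecs, chunks, k=5):
    # Rank contract chunks by similarity to the question and keep the top k.
    scored = sorted(
        zip(chunks, chunk_vecs),
        key=lambda cv: cosine(question_vec, cv[1]),
        reverse=True,
    )
    return [c for c, _ in scored[:k]]
```

In the app, the retrieved chunks would then be concatenated (capped at roughly 2.5K tokens) into the prompt's context section.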
+
+ ## ✏️ Clause Redlining Architecture (3-Tier)
+
+ | Tier | Method | Speed | Hallucination Risk |
+ |------|--------|-------|--------------------|
+ | **1. Template Lookup** | 18+ pre-written safe alternatives based on FTC/EU/CFPB standards | Instant | Zero |
+ | **2. Keyword Matching** | Match clause text to relevant templates via legal keywords | Instant | Zero |
+ | **3. LLM Refinement** | Qwen2.5-7B adapts the template to the specific clause context | ~3-5 s | Low (template-anchored) |
+
+ Anti-hallucination guardrails:
+ - **Template anchor:** the LLM can only refine, never generate from scratch
+ - **Legal citation:** every suggestion includes its legal basis and consumer standard
+ - **Disclaimer:** clear "Not legal advice" warning
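The tier fallthrough can be sketched as follows. The template text, keyword table, and `llm_refine` hook are illustrative placeholders (the real `redlining.py` carries 18+ templates with citations), but the control flow mirrors the table above: templates first, keywords second, LLM refinement only for CRITICAL/HIGH clauses and only anchored to an existing template.

```python
# Hypothetical data; the real project ships 18+ templates with legal citations.
TEMPLATES = {
    "Uncapped Liability": "Liability is capped at fees paid in the preceding 12 months.",
}
KEYWORDS = {"unlimited liability": "Uncapped Liability"}

def redline(category: str, clause_text: str, risk: str, llm_refine=None) -> str:
    # Tier 1: direct template lookup by clause category (zero hallucination risk).
    suggestion = TEMPLATES.get(category)
    # Tier 2: keyword match against the clause text.
    if suggestion is None:
        for kw, cat in KEYWORDS.items():
            if kw in clause_text.lower():
                suggestion = TEMPLATES.get(cat)
                break
    if suggestion is None:
        return ""  # no anchored template -> no suggestion (never free-generate)
    # Tier 3: LLM refinement, only for CRITICAL/HIGH clauses, anchored to the template.
    if llm_refine and risk in ("CRITICAL", "HIGH"):
        suggestion = llm_refine(template=suggestion, clause=clause_text)
    return suggestion
```

Returning an empty string when no template matches is the "template anchor" guardrail in code form: the model is never asked to invent a replacement clause on its own.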

 ## 📊 Risk Scoring Methodology

 Risk scores combine clause detection with weighted severity:

 - D (50-69): High risk
 - F (70+): Critical risk

 ## 🚀 Usage

 1. **Upload** a contract (PDF, DOCX, or TXT) or paste text directly
+    - 💡 Scanned PDFs are automatically processed with OCR
 2. Click **Analyze Contract**
 3. View results across tabs:
    - **Document**: Full text with inline entity highlights
    - **Contradictions**: Conflicting clauses and missing provisions
    - **Obligations**: Action items categorized by type
    - **Compliance**: Regulatory framework checks
+    - **Redlining**: ✏️ Safer clause alternatives with legal citations
 4. **Export** JSON/CSV reports
+ 5. Switch to the **💬 Contract Q&A** tab to ask questions about your contract

 ## 🔀 Compare Contracts

 - Upload or paste two contracts side-by-side
 - See clause-level diffs (added, removed, modified)
 - Get an alignment score and risk delta

 ## ⚠️ Disclaimer

 - [Clause Classifier Model](https://huggingface.co/Mokshith31/legalbert-contract-clause-classification)
 - [Legal-BERT Base](https://huggingface.co/nlpaueb/legal-bert-base-uncased)
 - [CUAD Dataset](https://huggingface.co/datasets/theatticusproject/cuad-qa)
+ - [Qwen2.5-7B (Chatbot LLM)](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
+ - [docTR OCR](https://github.com/mindee/doctr)
 - [CUAD Paper (arXiv:2103.06268)](https://arxiv.org/abs/2103.06268)

 ---
api/Dockerfile CHANGED
@@ -2,10 +2,16 @@ FROM python:3.12-slim

 WORKDIR /app

- COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt

- COPY . .

 EXPOSE 8000

 WORKDIR /app

+ # Install api dependencies
+ COPY api/requirements.txt ./requirements.txt
 RUN pip install --no-cache-dir -r requirements.txt

+ # Copy shared modules from root (needed by api/main.py)
+ COPY app.py compare.py compliance.py obligations.py ./
+ COPY ocr_engine.py chatbot.py redlining.py ./
+
+ # Copy api files
+ COPY api/ ./

 EXPOSE 8000
 
api/main.py CHANGED
@@ -1,19 +1,19 @@
 """
- ClauseGuard — FastAPI Backend v3.0
 ══════════════════════════════════
- FIXED in v3.0:
- • Imports shared modules (no code duplication)
- • Fixed API schema to accept both {text} and {clauses} from extension
- • Added rate limiting
- • Added max text length validation
- • Fixed CORS (removed wildcard)
- • Added proper error responses
 """

 import os
 import re
 import json
 import time
 from contextlib import asynccontextmanager
 from typing import Optional
 from collections import defaultdict
@@ -21,14 +21,14 @@ from datetime import datetime

 import httpx
 import numpy as np
- from fastapi import FastAPI, HTTPException, Depends, Body, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field

 from auth import get_current_user, require_auth

 # ── Import shared modules ──
- # When deployed, these must be in the same directory or on PYTHONPATH
 import sys
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -36,29 +36,32 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 try:
     from app import (
         split_clauses, classify_cuad, extract_entities,
-         detect_contradictions, compute_risk_score,
         CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
         cuad_model, cuad_tokenizer
     )
     from obligations import extract_obligations
     from compliance import check_compliance
     from compare import compare_contracts
     _SHARED_MODULES = True
- except ImportError:
     _SHARED_MODULES = False
-     print("[API] WARNING: Could not import shared modules, using inline fallbacks")

 # ─── Config ───
 SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
 SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
 HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
 SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
- MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "100000"))  # 100KB default

 # ─── Rate Limiting ───
- _rate_limits = {}  # ip -> (count, window_start)
 RATE_LIMIT_REQUESTS = 30
- RATE_LIMIT_WINDOW = 60  # seconds

 def _check_rate_limit(client_ip: str) -> bool:
     now = time.time()
@@ -113,25 +116,16 @@ async def supabase_query(table: str, params: dict, headers_extra: dict = {}):
     except Exception:
         return []

 # ─── Request/Response Models ───
 class AnalyzeRequest(BaseModel):
     text: Optional[str] = Field(None, min_length=50)
-     clauses: Optional[list] = None  # FIXED: accept clauses array from extension
     source_url: Optional[str] = None

- class AnalyzeResponse(BaseModel):
-     risk_score: int
-     grade: str
-     total_clauses: int
-     flagged_count: int
-     results: list[dict]
-     entities: list[dict]
-     contradictions: list[dict]
-     obligations: list[dict]
-     compliance: dict
-     model: str
-     latency_ms: int
-
 class CompareRequest(BaseModel):
     text_a: str = Field(..., min_length=50)
     text_b: str = Field(..., min_length=50)
@@ -147,21 +141,28 @@ class ExplainResponse(BaseModel):
     legal_basis: str
     recommendation: str

 # ─── App ───
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-     # Models are loaded when app.py is imported
     yield

- app = FastAPI(title="ClauseGuard API", version="3.0.0", lifespan=lifespan)

- # FIXED: No wildcard CORS
 ALLOWED_ORIGINS = [
     "https://clauseguardweb.netlify.app",
     "http://localhost:3000",
     "http://localhost:3001",
 ]
- # Allow chrome extensions
 app.add_middleware(
     CORSMiddleware,
     allow_origins=ALLOWED_ORIGINS,
@@ -174,36 +175,36 @@ app.add_middleware(
 @app.get("/health")
 async def health():
     model_status = "ml" if _SHARED_MODULES and cuad_model else "regex"
     return {
         "status": "ok",
         "model": model_status,
-         "version": "3.0.0",
         "shared_modules": _SHARED_MODULES,
     }

- @app.post("/api/analyze", response_model=AnalyzeResponse)
 async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
-     # Rate limiting
     client_ip = request.client.host if request.client else "unknown"
     if not _check_rate_limit(client_ip):
-         raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again in 60 seconds.")

-     # FIXED: Accept either text or clauses from extension
     text = req.text
     if not text and req.clauses:
         text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)

     if not text or len(text.strip()) < 50:
         raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
-
-     # Max length check
     if len(text) > MAX_TEXT_LENGTH:
-         raise HTTPException(status_code=400, detail=f"Text too long (maximum {MAX_TEXT_LENGTH} characters)")

     start = time.time()
     clauses = split_clauses(text)
     if not clauses:
-         raise HTTPException(status_code=400, detail="No clauses detected in document")

     clause_results = []
     for clause in clauses:
@@ -224,6 +225,15 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
     risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
     obligations = extract_obligations(text)
     compliance = check_compliance(text)
     latency = int((time.time() - start) * 1000)

     results_for_db = []
@@ -238,6 +248,29 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
         }],
     })

     if user:
         await supabase_insert("analyses", {
             "user_id": user["id"],
@@ -253,46 +286,120 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
             "compliance": compliance,
         })

-     return AnalyzeResponse(
-         risk_score=risk,
-         grade=grade,
-         total_clauses=len(clauses),
-         flagged_count=len(set(cr["text"] for cr in clause_results)),
-         results=results_for_db,
-         entities=entities,
-         contradictions=contradictions,
-         obligations=obligations,
-         compliance=compliance,
-         model="ml" if cuad_model else "regex",
-         latency_ms=latency,
-     )

 @app.post("/api/compare")
 async def compare(req: CompareRequest, request: Request):
     client_ip = request.client.host if request.client else "unknown"
     if not _check_rate_limit(client_ip):
         raise HTTPException(status_code=429, detail="Rate limit exceeded.")
-     result = compare_contracts(req.text_a, req.text_b)
-     return result

 @app.post("/api/explain", response_model=ExplainResponse)
 async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
     desc = DESC_MAP.get(req.category, "Unknown category.")
     legal = "Consult local consumer protection laws."
-     recommendation = "Review this clause carefully. Consider negotiating or seeking legal advice before agreeing."

     if SAULLM_ENDPOINT and HF_API_TOKEN:
         try:
             prompt = (
-                 f"You are a consumer protection legal analyst. Analyze this contract clause "
-                 f"and explain why it may be unfair or risky.\n\n"
-                 f"Clause: \"{req.clause}\"\n"
-                 f"Category: {req.category}\n\n"
-                 f"Provide:\n"
-                 f"1. A plain-English explanation of what this clause means\n"
-                 f"2. The specific legal basis or consumer protection concern\n"
-                 f"3. A practical recommendation\n\n"
-                 f"Be concise. 3-4 sentences per section."
             )
             async with httpx.AsyncClient(timeout=30.0) as client:
                 resp = await client.post(
@@ -311,27 +418,16 @@ async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
         except Exception:
             pass

-     return ExplainResponse(
-         clause=req.clause,
-         category=req.category,
-         explanation=desc,
-         legal_basis=legal,
-         recommendation=recommendation,
-     )

 @app.get("/api/history")
 async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
     limit = min(limit, 100)
-     data = await supabase_query(
-         "analyses",
-         {
-             "user_id": f"eq.{user['id']}",
-             "select": "*",
-             "order": "created_at.desc",
-             "limit": str(limit),
-             "offset": str(offset),
-         },
-     )
     return {"analyses": data, "limit": limit, "offset": offset}

 if __name__ == "__main__":
 
 """
+ ClauseGuard — FastAPI Backend v4.0
 ══════════════════════════════════
+ New in v4.0:
+ • /api/redline: clause redlining suggestions
+ • /api/chat: RAG chatbot (streaming)
+ • /api/ocr: OCR scanned-PDF extraction
+ • Updated analysis to include redlining data
 """

 import os
 import re
 import json
 import time
+ import uuid
+ import tempfile
 from contextlib import asynccontextmanager
 from typing import Optional
 from collections import defaultdict

 import httpx
 import numpy as np
+ from fastapi import FastAPI, HTTPException, Depends, Body, Request, UploadFile, File as FastAPIFile
 from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field

 from auth import get_current_user, require_auth

 # ── Import shared modules ──
 import sys
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 try:
     from app import (
         split_clauses, classify_cuad, extract_entities,
+         detect_contradictions, compute_risk_score, analyze_contract,
         CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
         cuad_model, cuad_tokenizer
     )
     from obligations import extract_obligations
     from compliance import check_compliance
     from compare import compare_contracts
+     from redlining import generate_redlines
+     from chatbot import index_contract, chat_respond
+     from ocr_engine import parse_pdf_smart, get_ocr_status
     _SHARED_MODULES = True
+ except ImportError as e:
     _SHARED_MODULES = False
+     print(f"[API] WARNING: Could not import shared modules: {e}")

 # ─── Config ───
 SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
 SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
 HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
 SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
+ MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "100000"))

 # ─── Rate Limiting ───
+ _rate_limits = {}
 RATE_LIMIT_REQUESTS = 30
+ RATE_LIMIT_WINDOW = 60

 def _check_rate_limit(client_ip: str) -> bool:
     now = time.time()
 
     except Exception:
         return []

+ # ─── In-memory RAG session store ───
+ _rag_sessions: dict = {}
+ _RAG_SESSION_MAX = 100
+
 # ─── Request/Response Models ───
 class AnalyzeRequest(BaseModel):
     text: Optional[str] = Field(None, min_length=50)
+     clauses: Optional[list] = None
     source_url: Optional[str] = None

 class CompareRequest(BaseModel):
     text_a: str = Field(..., min_length=50)
     text_b: str = Field(..., min_length=50)

     legal_basis: str
     recommendation: str

+ class ChatRequest(BaseModel):
+     message: str = Field(..., min_length=1, max_length=2000)
+     session_id: str
+     history: Optional[list[dict]] = None
+
+ class RedlineRequest(BaseModel):
+     session_id: Optional[str] = None
+     text: Optional[str] = None
+     use_llm: bool = True
+
 # ─── App ───
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     yield

+ app = FastAPI(title="ClauseGuard API", version="4.0.0", lifespan=lifespan)

 ALLOWED_ORIGINS = [
     "https://clauseguardweb.netlify.app",
     "http://localhost:3000",
     "http://localhost:3001",
 ]
 app.add_middleware(
     CORSMiddleware,
     allow_origins=ALLOWED_ORIGINS,

 @app.get("/health")
 async def health():
     model_status = "ml" if _SHARED_MODULES and cuad_model else "regex"
+     ocr_status = get_ocr_status() if _SHARED_MODULES else "unavailable"
     return {
         "status": "ok",
         "model": model_status,
+         "version": "4.0.0",
         "shared_modules": _SHARED_MODULES,
+         "ocr": ocr_status,
+         "features": ["analyze", "compare", "redline", "chat", "ocr"],
     }

+ @app.post("/api/analyze")
 async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
     client_ip = request.client.host if request.client else "unknown"
     if not _check_rate_limit(client_ip):
+         raise HTTPException(status_code=429, detail="Rate limit exceeded.")

     text = req.text
     if not text and req.clauses:
         text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)

     if not text or len(text.strip()) < 50:
         raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
     if len(text) > MAX_TEXT_LENGTH:
+         raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} chars)")

     start = time.time()
+
     clauses = split_clauses(text)
     if not clauses:
+         raise HTTPException(status_code=400, detail="No clauses detected")

     clause_results = []
     for clause in clauses:

     risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
     obligations = extract_obligations(text)
     compliance = check_compliance(text)
+
+     # v4.0: Redlining
+     analysis_for_redline = {"clauses": clause_results}
+     redlines = []
+     try:
+         redlines = generate_redlines(analysis_for_redline, use_llm=True)
+     except Exception as e:
+         print(f"[API] Redlining error: {e}")
+
     latency = int((time.time() - start) * 1000)

     results_for_db = []

         }],
     })

+     # v4.0: RAG indexing
+     session_id = None
+     try:
+         chunks, embeddings, _status = index_contract(text)
+         if chunks and embeddings is not None:
+             session_id = uuid.uuid4().hex[:12]
+             if len(_rag_sessions) >= _RAG_SESSION_MAX:
+                 oldest = next(iter(_rag_sessions))
+                 del _rag_sessions[oldest]
+             _rag_sessions[session_id] = {
+                 "chunks": chunks,
+                 "embeddings": embeddings,
+                 "analysis": {
+                     "risk": {"score": risk, "grade": grade, "breakdown": sev_counts},
+                     "metadata": {"total_clauses": len(clauses), "flagged_clauses": len(clause_results)},
+                     "clauses": clause_results[:30],
+                     "entities": entities[:30],
+                     "contradictions": contradictions,
+                 },
+             }
+     except Exception as e:
+         print(f"[API] RAG indexing error: {e}")
+
     if user:
         await supabase_insert("analyses", {
             "user_id": user["id"],

             "compliance": compliance,
         })

+     return {
+         "risk_score": risk,
+         "grade": grade,
+         "total_clauses": len(clauses),
+         "flagged_count": len(set(cr["text"] for cr in clause_results)),
+         "results": results_for_db,
+         "entities": entities,
+         "contradictions": contradictions,
+         "obligations": obligations,
+         "compliance": compliance,
+         "redlines": redlines,
+         "model": "ml" if cuad_model else "regex",
+         "latency_ms": latency,
+         "session_id": session_id,
+     }

 @app.post("/api/compare")
 async def compare(req: CompareRequest, request: Request):
     client_ip = request.client.host if request.client else "unknown"
     if not _check_rate_limit(client_ip):
         raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+     return compare_contracts(req.text_a, req.text_b)
+
+ @app.post("/api/redline")
+ async def redline(req: RedlineRequest, request: Request):
+     client_ip = request.client.host if request.client else "unknown"
+     if not _check_rate_limit(client_ip):
+         raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+
+     if req.session_id and req.session_id in _rag_sessions:
+         analysis = _rag_sessions[req.session_id]["analysis"]
+     elif req.text:
+         result, error = analyze_contract(req.text)
+         if error:
+             raise HTTPException(status_code=400, detail=error)
+         analysis = result
+     else:
+         raise HTTPException(status_code=400, detail="Provide session_id or text")
+
+     redlines = generate_redlines(analysis, use_llm=req.use_llm)
+     return {"redlines": redlines, "count": len(redlines)}
+
+ @app.post("/api/chat")
+ async def chat(req: ChatRequest, request: Request):
+     client_ip = request.client.host if request.client else "unknown"
+     if not _check_rate_limit(client_ip):
+         raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+
+     if req.session_id not in _rag_sessions:
+         raise HTTPException(status_code=404, detail="Session not found. Analyze a contract first.")
+
+     session = _rag_sessions[req.session_id]
+     response_text = ""
+     for partial in chat_respond(req.message, req.history or [],
+                                 session["chunks"], session["embeddings"], session["analysis"]):
+         response_text = partial
+
+     return {"response": response_text, "session_id": req.session_id}
+
+ @app.post("/api/chat/stream")
+ async def chat_stream(req: ChatRequest, request: Request):
+     client_ip = request.client.host if request.client else "unknown"
+     if not _check_rate_limit(client_ip):
+         raise HTTPException(status_code=429, detail="Rate limit exceeded.")
+
+     if req.session_id not in _rag_sessions:
+         raise HTTPException(status_code=404, detail="Session not found.")
+
+     session = _rag_sessions[req.session_id]
+
+     async def generate():
+         last = ""
+         for partial in chat_respond(
+             req.message, req.history or [],
+             session["chunks"], session["embeddings"], session["analysis"]
+         ):
+             delta = partial[len(last):]
+             last = partial
+             if delta:
+                 yield f"data: {json.dumps({'delta': delta})}\n\n"
+         yield "data: [DONE]\n\n"
+
+     return StreamingResponse(generate(), media_type="text/event-stream")
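On the client side, the `data: {"delta": ...}` frames emitted by `generate()` above would be reassembled into the full answer. A minimal sketch of that reassembly, assuming exactly the frame shapes shown (JSON `delta` payloads terminated by a `[DONE]` sentinel):

```python
import json

def reassemble_sse(frames: list[str]) -> str:
    # Concatenate the `delta` fields of SSE data frames until [DONE].
    text = ""
    for frame in frames:
        if not frame.startswith("data: "):
            continue  # ignore comments/blank keep-alive lines
        payload = frame[len("data: "):].strip()
        if payload == "[DONE]":
            break
        text += json.loads(payload).get("delta", "")
    return text
```

Sending only deltas (rather than re-sending the growing prefix on every event) keeps each frame small, which matters for long chatbot answers.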
+
+ @app.post("/api/ocr")
+ async def ocr_endpoint(file: UploadFile = FastAPIFile(...)):
+     if not file.filename or not file.filename.lower().endswith(".pdf"):
+         raise HTTPException(status_code=400, detail="Only PDF files supported")
+
+     with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
+         content = await file.read()
+         tmp.write(content)
+         tmp_path = tmp.name
+
+     try:
+         text, error, method = parse_pdf_smart(tmp_path)
+         if error:
+             raise HTTPException(status_code=400, detail=error)
+         return {"text": text, "method": method, "chars": len(text) if text else 0, "filename": file.filename}
+     finally:
+         os.unlink(tmp_path)

 @app.post("/api/explain", response_model=ExplainResponse)
 async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
     desc = DESC_MAP.get(req.category, "Unknown category.")
     legal = "Consult local consumer protection laws."
+     recommendation = "Review this clause carefully."

     if SAULLM_ENDPOINT and HF_API_TOKEN:
         try:
             prompt = (
+                 f"Analyze this contract clause and explain why it may be risky.\n\n"
+                 f"Clause: \"{req.clause}\"\nCategory: {req.category}\n\n"
+                 f"Provide: 1) Plain-English explanation 2) Legal basis 3) Recommendation"
             )
             async with httpx.AsyncClient(timeout=30.0) as client:
                 resp = await client.post(
         except Exception:
             pass

+     return ExplainResponse(clause=req.clause, category=req.category,
+                            explanation=desc, legal_basis=legal, recommendation=recommendation)

 @app.get("/api/history")
 async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
     limit = min(limit, 100)
+     data = await supabase_query("analyses", {
+         "user_id": f"eq.{user['id']}", "select": "*",
+         "order": "created_at.desc", "limit": str(limit), "offset": str(offset),
+     })
     return {"analyses": data, "limit": limit, "offset": offset}

 if __name__ == "__main__":
api/requirements.txt CHANGED
@@ -1,10 +1,13 @@
- fastapi>=0.136.0
- uvicorn[standard]>=0.46.0
- pydantic>=2.13.3
- transformers>=5.6.1
+ fastapi>=0.115.0
+ uvicorn[standard]>=0.34.0
+ pydantic>=2.10.0
+ transformers>=4.45.0
  numpy>=2.0.0
  python-jose[cryptography]>=3.3.0
  httpx>=0.28.0
  peft>=0.15.0
  torch>=2.5.0
  sentence-transformers>=3.0.0
+ python-doctr[torch]>=0.9.0
+ huggingface_hub>=0.25.0
+ python-multipart>=0.0.7
app.py CHANGED
@@ -1,7 +1,12 @@
  """
- ClauseGuard — World's Best Legal Contract Analysis Tool (v3.0)
  ═══════════════════════════════════════════════════════════════
- Fixes in v3.0:
  • Fixed CUAD label mapping (added missing index 6: "Notice Period to Terminate Renewal")
  • Switched from softmax → sigmoid for proper multi-label classification
  • Per-class optimized thresholds instead of flat 0.15
@@ -21,6 +26,9 @@ Models:
  (LoRA adapter on nlpaueb/legal-bert-base-uncased, 41 CUAD classes)
  • Legal NER: matterstack/legal-bert-ner (token classification)
  • NLI: cross-encoder/nli-deberta-v3-base (contradiction detection)
  """

  import os
@@ -71,6 +79,9 @@ except Exception:
  from compare import compare_contracts, render_comparison_html
  from obligations import extract_obligations, render_obligations_html
  from compliance import check_compliance, render_compliance_html

  # ═══════════════════════════════════════════════════════════════════════
  # 1. CONFIGURATION — FIXED label mapping (41 labels, index 6 restored)
@@ -335,20 +346,15 @@ _load_nli_model()
  # ═══════════════════════════════════════════════════════════════════════

  def parse_pdf(file_path):
-     if not _HAS_PDF:
-         return None, "PDF parsing not available (pdfplumber not installed)"
-     try:
-         text = ""
-         with pdfplumber.open(file_path) as pdf:
-             for page in pdf.pages:
-                 page_text = page.extract_text()
-                 if page_text:
-                     text += page_text + "\n\n"
-         if not text.strip():
-             return None, "PDF appears to be scanned/image-based. OCR is not yet supported. Please use a digital PDF or paste text directly."
-         return text.strip(), None
-     except Exception as e:
-         return None, f"PDF parse error: {e}"

  def parse_docx(file_path):
      if not _HAS_DOCX:
@@ -378,11 +384,22 @@ def parse_document(file_path):
      return None, f"Unsupported file type: {ext}"

  # ═══════════════════════════════════════════════════════════════════════
- # 4. STRUCTURE-AWARE CLAUSE SPLITTING
  # ═══════════════════════════════════════════════════════════════════════

  def split_clauses(text):
-     """Structure-aware clause splitting that respects section numbering."""
      text = re.sub(r'\n{3,}', '\n\n', text.strip())

      # First try to detect numbered sections (1., 2., 3.1, (a), etc.)
@@ -426,9 +443,13 @@
          preamble = text[:positions[0]].strip()
          if len(preamble) > 30:
              clauses.insert(0, preamble)
-         return clauses if clauses else _fallback_split(text)
      else:
-         return _fallback_split(text)

  def _fallback_split(text):
      """Fallback: split on paragraph breaks and sentence boundaries."""
@@ -462,8 +483,40 @@

  # ═══════════════════════════════════════════════════════════════════════
  # 5. CLAUSE DETECTION — FIXED: sigmoid + per-class thresholds + caching
  # ═══════════════════════════════════════════════════════════════════════

  def _text_hash(text):
      return hashlib.md5(text.encode()).hexdigest()

@@ -474,14 +527,17 @@ def classify_cuad(clause_text):
      if cuad_model is None or cuad_tokenizer is None:
          return _classify_regex(clause_text)

      # Check cache
-     h = _text_hash(clause_text[:512])
      if h in _prediction_cache:
          return _prediction_cache[h]

      try:
          inputs = cuad_tokenizer(
-             clause_text,
              return_tensors="pt",
              truncation=True,
              max_length=256,
@@ -498,10 +554,15 @@
              threshold = _CUAD_THRESHOLDS.get(i, 0.40)
              if float(prob) > threshold and i < len(CUAD_LABELS):
                  label = CUAD_LABELS[i]
                  risk = RISK_MAP.get(label, "LOW")
                  results.append({
                      "label": label,
-                     "confidence": round(float(prob), 3),
                      "risk": risk,
                      "description": DESC_MAP.get(label, label),
                      "source": "ml",
@@ -773,19 +834,33 @@ def detect_contradictions(clause_results, raw_text=""):
          "source": "heuristic",
      })

-     # ── 2. Missing critical clauses ──
-     critical_clauses = {
-         "Governing Law": "No governing law clause detected — jurisdiction ambiguity may cause disputes.",
-         "Termination for Convenience": "No termination clause detected — exit terms are unclear.",
-         "Limitation of liability": "No liability limitation detected — exposure may be unlimited.",
      }
-     for cc, explanation in critical_clauses.items():
-         if cc not in labels_found:
              contradictions.append({
                  "type": "MISSING",
-                 "explanation": explanation,
                  "severity": "MEDIUM",
-                 "clauses": [cc],
                  "source": "structural",
              })

@@ -847,13 +922,21 @@ def analyze_contract(text)
      contradictions = detect_contradictions(clause_results, text)
      risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
      obligations = extract_obligations(text)
      compliance = check_compliance(text)

      result = {
          "metadata": {
              "analysis_date": datetime.now().isoformat(),
              "total_clauses": len(clauses),
-             "flagged_clauses": len(set(cr["text"] for cr in clause_results)),
              "model": get_model_status_text(),
          },
          "risk": {
              "score": risk,
@@ -1119,11 +1202,11 @@ def process_upload(file)
  def run_analysis(text):
      if not text or len(text.strip()) < 50:
          err_html = '<p style="color:#dc2626;padding:16px;">Document too short (minimum 50 characters)</p>'
-         return [err_html] * 7 + [None, None, ""]
      result, error = analyze_contract(text)
      if error:
          err_html = f'<p style="color:#dc2626;padding:16px;">{error}</p>'
-         return [err_html] * 7 + [None, None, error]

      # FIXED: per-session temp files
      session_id = uuid.uuid4().hex[:8]
@@ -1136,6 +1219,10 @@
      with open(csv_path, "w") as f:
          f.write(csv_content)

      return [
          render_summary(result),
          render_clause_cards(result),
@@ -1144,13 +1231,15 @@
          render_document_viewer(result),
          render_obligations_html(result.get("obligations", [])),
          render_compliance_html(result.get("compliance", {})),
          json_path,
          csv_path,
          "Analysis complete",
      ]

  def do_clear():
-     return [""] * 7 + [None, None, ""]

  # ── Example contracts ──
  SPOTIFY_TOS = """By using the Spotify Service, you agree to be bound by these Terms of Use.
@@ -1234,17 +1323,22 @@ with gr.Blocks(
  """
  ) as demo:

      gr.HTML("""
      <div style="display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:2px solid #e5e7eb;margin-bottom:16px;">
      <div>
      <h1 style="font-size:24px;font-weight:700;margin:0;color:#1f2937;">🛡️ ClauseGuard</h1>
-     <p style="font-size:13px;color:#6b7280;margin:4px 0 0 0;">AI-Powered Legal Contract Analysis · 41 Clause Categories · Risk Scoring · ML NER · NLI Contradictions · Compliance · Obligations</p>
      </div>
-     <div style="font-size:12px;color:#9ca3af;">v3.0 · Precision Legal AI</div>
      </div>
      """)

-     # ── Main Tabs: Analysis vs Comparison ──
      with gr.Tabs():

          # ═══════ TAB 1: Single Contract Analysis ═══════
@@ -1261,7 +1355,7 @@
          with gr.Column(scale=3):
              text_input = gr.Textbox(
                  label="📄 Contract Text",
-                 placeholder="Paste contract text here, or upload a file above...",
                  lines=14,
                  max_lines=40,
                  show_copy_button=True,
@@ -1304,6 +1398,8 @@
              obligations_html = gr.HTML(label="Obligation Tracker")
          with gr.Tab("⚖️ Compliance"):
              compliance_html = gr.HTML(label="Compliance Checker")

          # ═══════ TAB 2: Contract Comparison ═══════
          with gr.Tab("🔀 Compare Contracts"):
@@ -1352,6 +1448,53 @@
              with gr.Column(scale=2):
                  comp_json = gr.JSON(label="Raw Comparison Data")
      # ── Events ──
      def _load_file(file):
          text, err = parse_document(file) if file else ("", "No file")
@@ -1359,23 +1502,41 @@ with gr.Blocks(
              return "", err
          return text, "Loaded successfully" if not err else err

      load_btn.click(_load_file, inputs=[file_input], outputs=[text_input, load_status])
      comp_load_a.click(_load_file, inputs=[comp_file_a], outputs=[comp_text_a, comp_status_a])
      comp_load_b.click(_load_file, inputs=[comp_file_b], outputs=[comp_text_b, comp_status_b])

      scan_btn.click(
-         run_analysis,
          inputs=[text_input],
-         outputs=[summary_html, clauses_html, entities_html, nli_html,
-                  doc_html, obligations_html, compliance_html,
-                  json_file, csv_file, status_msg]
      )

      clear_btn.click(
-         do_clear,
-         outputs=[summary_html, clauses_html, entities_html, nli_html,
-                  doc_html, obligations_html, compliance_html,
-                  json_file, csv_file, status_msg]
      )

      comp_btn.click(
@@ -1391,6 +1552,8 @@
  · Model: <a href="https://huggingface.co/Mokshith31/legalbert-contract-clause-classification" style="color:#6b7280;">Legal-BERT + CUAD (41 classes)</a>
  · NER: <a href="https://huggingface.co/matterstack/legal-bert-ner" style="color:#6b7280;">Legal-BERT NER</a>
  · NLI: <a href="https://huggingface.co/cross-encoder/nli-deberta-v3-base" style="color:#6b7280;">DeBERTa-v3 NLI</a>
  · Dataset: <a href="https://huggingface.co/datasets/theatticusproject/cuad-qa" style="color:#6b7280;">CUAD</a>
  · <a href="https://huggingface.co/spaces/gaurv007/ClauseGuard" style="color:#6b7280;">ClauseGuard Space</a>
  </p>
 
  """
+ ClauseGuard — World's Best Legal Contract Analysis Tool (v4.0)
  ═══════════════════════════════════════════════════════════════
+ New in v4.0:
+ • OCR support for scanned PDFs (docTR engine with smart native/scanned routing)
+ • Contract Q&A Chatbot (RAG: embedding retrieval + HF Inference API streaming)
+ • Clause Redlining (3-tier: template lookup + RAG + LLM refinement)
+
+ Carried from v3.0:
  • Fixed CUAD label mapping (added missing index 6: "Notice Period to Terminate Renewal")
  • Switched from softmax → sigmoid for proper multi-label classification
  • Per-class optimized thresholds instead of flat 0.15

  (LoRA adapter on nlpaueb/legal-bert-base-uncased, 41 CUAD classes)
  • Legal NER: matterstack/legal-bert-ner (token classification)
  • NLI: cross-encoder/nli-deberta-v3-base (contradiction detection)
+ • Embeddings: sentence-transformers/all-MiniLM-L6-v2 (RAG retrieval)
+ • OCR: docTR fast_base + crnn_vgg16_bn (scanned PDF extraction)
+ • LLM: Qwen/Qwen2.5-7B-Instruct via HF Inference API (chatbot + redlining)
  """

  import os

  from compare import compare_contracts, render_comparison_html
  from obligations import extract_obligations, render_obligations_html
  from compliance import check_compliance, render_compliance_html
+ from ocr_engine import parse_pdf_smart, get_ocr_status
+ from chatbot import index_contract, chat_respond, get_chatbot_status
+ from redlining import generate_redlines, render_redlines_html

  # ═══════════════════════════════════════════════════════════════════════
  # 1. CONFIGURATION — FIXED label mapping (41 labels, index 6 restored)
 
  # ═══════════════════════════════════════════════════════════════════════

  def parse_pdf(file_path):
+     """Smart PDF parser: native text extraction with OCR fallback for scanned PDFs."""
+     text, error, method = parse_pdf_smart(file_path)
+     if text:
+         if method == "ocr":
+             print(f"[ClauseGuard] PDF extracted via OCR ({len(text)} chars)")
+         return text, None
+     if error:
+         return None, error
+     return None, "Could not extract text from PDF. Try uploading a clearer scan or digital PDF."

  def parse_docx(file_path):
      if not _HAS_DOCX:
 
      return None, f"Unsupported file type: {ext}"

  # ═══════════════════════════════════════════════════════════════════════
+ # 4. DETERMINISTIC CLAUSE SPLITTING (Fix 1 from bug report)
  # ═══════════════════════════════════════════════════════════════════════

+ # Document-level chunk cache: same text always produces same chunks
+ _chunk_cache = {}
+
  def split_clauses(text):
+     """Deterministic, structure-aware clause splitting.
+     Fix 1: Same input ALWAYS produces same output. Normalized text is hashed
+     and cached so repeated runs on identical documents are identical."""
+     # Normalize whitespace before hashing for determinism
+     normalized = re.sub(r'\s+', ' ', text.strip())
+     text_hash = hashlib.sha256(normalized.encode()).hexdigest()
+     if text_hash in _chunk_cache:
+         return _chunk_cache[text_hash]
+
      text = re.sub(r'\n{3,}', '\n\n', text.strip())

      # First try to detect numbered sections (1., 2., 3.1, (a), etc.)

          preamble = text[:positions[0]].strip()
          if len(preamble) > 30:
              clauses.insert(0, preamble)
+         result = clauses if clauses else _fallback_split(text)
+         _chunk_cache[text_hash] = result
+         return result
      else:
+         result = _fallback_split(text)
+         _chunk_cache[text_hash] = result
+         return result

  def _fallback_split(text):
      """Fallback: split on paragraph breaks and sentence boundaries."""
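The determinism fix keys the chunk cache on a whitespace-normalized SHA-256 of the input, so two copies of the same contract that differ only in line wrapping hit the same cache entry. A standalone sketch of that hashing step (a hypothetical helper mirroring the normalization used in `split_clauses`):

```python
import hashlib
import re

def doc_key(text: str) -> str:
    """Whitespace-normalized SHA-256 key, as used for the document-level chunk cache."""
    # Collapse all runs of whitespace so reflowed text maps to the same key
    normalized = re.sub(r'\s+', ' ', text.strip())
    return hashlib.sha256(normalized.encode()).hexdigest()
```

Because the key ignores layout, re-uploading a reformatted PDF extraction of the same document returns the cached (identical) clause list.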
 
  # ═══════════════════════════════════════════════════════════════════════
  # 5. CLAUSE DETECTION — FIXED: sigmoid + per-class thresholds + caching
+ # Fix 3: Strip section headings before classification
+ # Fix 6: Label guardrails for high-confidence false positives
  # ═══════════════════════════════════════════════════════════════════════

+ # Fix 3: Section heading pattern — strip before classifying
+ _HEADING_RE = re.compile(r'^\d+(?:\.\d+)*\s+[A-Z][A-Z\s&,/]+$', re.MULTILINE)
+
+ def _strip_heading(text):
+     """Remove leading section headings that confuse the classifier."""
+     lines = text.split('\n')
+     if lines and _HEADING_RE.match(lines[0].strip()):
+         stripped = '\n'.join(lines[1:]).strip()
+         return stripped if len(stripped) > 20 else text
+     return text
+
+ # Fix 6: Label guardrails — keyword validation for high-confidence labels
+ _LABEL_GUARDRAILS = {
+     "Liquidated Damages": re.compile(
+         r'liquidated|pre-?determined.{0,10}damage|agreed.{0,10}sum|penalty clause|stipulated.{0,10}damage',
+         re.IGNORECASE
+     ),
+     "Uncapped Liability": re.compile(
+         r'uncapped|unlimited.{0,10}liabilit|no.{0,10}(limit|cap).{0,10}liabilit',
+         re.IGNORECASE
+     ),
+ }
+
+ def _apply_guardrails(label, text, confidence):
+     """Fix 6: If label has a guardrail and text lacks required keywords, demote."""
+     guard = _LABEL_GUARDRAILS.get(label)
+     if guard and not guard.search(text):
+         return "Other", confidence * 0.3  # demote to Other with reduced confidence
+     return label, confidence
+
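The guardrail idea can be exercised in isolation: a high-confidence label only survives if the clause text contains a corroborating keyword. A self-contained sketch, simplified to a single pattern (the real table is `_LABEL_GUARDRAILS` above):

```python
import re

# One-entry stand-in for the guardrail table
GUARD = {"Liquidated Damages": re.compile(r'liquidated|stipulated.{0,10}damage', re.IGNORECASE)}

def apply_guardrail(label, text, confidence):
    """Demote a predicted label to 'Other' when its required keywords are absent."""
    guard = GUARD.get(label)
    if guard and not guard.search(text):
        return "Other", confidence * 0.3
    return label, confidence
```

The 0.3 multiplier pushes demoted predictions below the downstream filter, so a classifier that fires on superficially similar wording never surfaces the label to the user.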
  def _text_hash(text):
      return hashlib.md5(text.encode()).hexdigest()

      if cuad_model is None or cuad_tokenizer is None:
          return _classify_regex(clause_text)

+     # Fix 3: Strip section headings before classification
+     clean_text = _strip_heading(clause_text)
+
      # Check cache
+     h = _text_hash(clean_text[:512])
      if h in _prediction_cache:
          return _prediction_cache[h]

      try:
          inputs = cuad_tokenizer(
+             clean_text,
              return_tensors="pt",
              truncation=True,
              max_length=256,

              threshold = _CUAD_THRESHOLDS.get(i, 0.40)
              if float(prob) > threshold and i < len(CUAD_LABELS):
                  label = CUAD_LABELS[i]
+                 conf = float(prob)
+                 # Fix 6: Apply guardrails — reject high-confidence false positives
+                 label, conf = _apply_guardrails(label, clause_text, conf)
+                 if label == "Other" and conf < 0.3:
+                     continue  # Skip demoted labels
                  risk = RISK_MAP.get(label, "LOW")
                  results.append({
                      "label": label,
+                     "confidence": round(conf, 3),
                      "risk": risk,
                      "description": DESC_MAP.get(label, label),
                      "source": "ml",
 
          "source": "heuristic",
      })

+     # ── 2. Missing critical clauses (Fix 4: check raw_text, not labels) ──
+     _REQUIRED_CLAUSE_PATTERNS = {
+         "Governing Law": re.compile(
+             r'govern(?:ed|ing).{0,15}law|applicable.{0,10}law|laws?\s+of\s+the\s+state',
+             re.IGNORECASE
+         ),
+         "Limitation of liability": re.compile(
+             r'limitation.{0,10}liabilit|cap.{0,10}liabilit|liabilit.{0,10}shall\s+not\s+exceed|in\s+no\s+event.{0,20}liable',
+             re.IGNORECASE
+         ),
+         "Arbitration": re.compile(
+             r'arbitrat|AAA|JAMS|binding.{0,10}dispute',
+             re.IGNORECASE
+         ),
+         "Termination": re.compile(
+             r'terminat(?:e|ion|ed)|cancel(?:lation)?',
+             re.IGNORECASE
+         ),
      }
+     for clause_name, pattern in _REQUIRED_CLAUSE_PATTERNS.items():
+         # Check raw_text directly — it's stable and deterministic
+         if not pattern.search(raw_text):
              contradictions.append({
                  "type": "MISSING",
+                 "explanation": f"No '{clause_name}' clause detected in the document.",
                  "severity": "MEDIUM",
+                 "clauses": [clause_name],
                  "source": "structural",
              })
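Scanning the raw text for required-clause patterns, as above, is order-independent and does not depend on what the classifier happened to flag. A compact sketch of the same scan with abbreviated patterns (the full set lives in `_REQUIRED_CLAUSE_PATTERNS`):

```python
import re

# Abbreviated required-clause patterns for illustration
REQUIRED = {
    "Governing Law": re.compile(r'govern(?:ed|ing).{0,15}law', re.IGNORECASE),
    "Termination": re.compile(r'terminat(?:e|ion|ed)', re.IGNORECASE),
}

def missing_clauses(raw_text):
    """Return the names of required clauses whose patterns never match the document."""
    return [name for name, pat in REQUIRED.items() if not pat.search(raw_text)]
```

Because the check runs on the full document rather than per-chunk labels, a clause split awkwardly across two chunks still counts as present.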
 
 
      contradictions = detect_contradictions(clause_results, text)
      risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
      obligations = extract_obligations(text)
+     # Fix 5: Compliance runs against full raw_text (already done in compliance.py)
      compliance = check_compliance(text)
+
+     # Fix 2: Compute flagged_clauses AFTER all processing is complete
+     flagged_clause_count = len(clause_results)
+     unique_flagged_texts = len(set(cr["text"] for cr in clause_results))
+
      result = {
          "metadata": {
              "analysis_date": datetime.now().isoformat(),
              "total_clauses": len(clauses),
+             "flagged_clauses": flagged_clause_count,
+             "unique_flagged": unique_flagged_texts,
              "model": get_model_status_text(),
+             "text_hash": hashlib.sha256(re.sub(r'\s+', ' ', text.strip()).encode()).hexdigest()[:16],
          },
          "risk": {
              "score": risk,
 
  def run_analysis(text):
      if not text or len(text.strip()) < 50:
          err_html = '<p style="color:#dc2626;padding:16px;">Document too short (minimum 50 characters)</p>'
+         return [err_html] * 8 + [None, None, "", None]
      result, error = analyze_contract(text)
      if error:
          err_html = f'<p style="color:#dc2626;padding:16px;">{error}</p>'
+         return [err_html] * 8 + [None, None, error, None]

      # FIXED: per-session temp files
      session_id = uuid.uuid4().hex[:8]

      with open(csv_path, "w") as f:
          f.write(csv_content)

+     # Generate redline suggestions (Tier 1 template + Tier 3 LLM for critical/high)
+     redlines = generate_redlines(result, use_llm=True)
+     redlines_html = render_redlines_html(redlines)
+
      return [
          render_summary(result),
          render_clause_cards(result),

          render_document_viewer(result),
          render_obligations_html(result.get("obligations", [])),
          render_compliance_html(result.get("compliance", {})),
+         redlines_html,
          json_path,
          csv_path,
          "Analysis complete",
+         result,  # Store analysis result for chatbot
      ]

  def do_clear():
+     return [""] * 8 + [None, None, "", None]

  # ── Example contracts ──
  SPOTIFY_TOS = """By using the Spotify Service, you agree to be bound by these Terms of Use.
 
  """
  ) as demo:

+     # ── Shared State (for chatbot RAG) ──────────────────────────────
+     analysis_state = gr.State(None)      # Full analysis result dict
+     chunks_state = gr.State([])          # Contract text chunks for RAG
+     embeddings_state = gr.State(None)    # Chunk embeddings (numpy array)
+
      gr.HTML("""
      <div style="display:flex;align-items:center;justify-content:space-between;padding:12px 0;border-bottom:2px solid #e5e7eb;margin-bottom:16px;">
      <div>
      <h1 style="font-size:24px;font-weight:700;margin:0;color:#1f2937;">🛡️ ClauseGuard</h1>
+     <p style="font-size:13px;color:#6b7280;margin:4px 0 0 0;">AI-Powered Legal Contract Analysis · 41 Clause Categories · Risk Scoring · ML NER · NLI Contradictions · Compliance · Obligations · <strong>Q&A Chatbot</strong> · <strong>Clause Redlining</strong> · <strong>OCR</strong></p>
      </div>
+     <div style="font-size:12px;color:#9ca3af;">v4.0 · Precision Legal AI</div>
      </div>
      """)

+     # ── Main Tabs: Analysis vs Comparison vs Chatbot ──
      with gr.Tabs():

          # ═══════ TAB 1: Single Contract Analysis ═══════
 
          with gr.Column(scale=3):
              text_input = gr.Textbox(
                  label="📄 Contract Text",
+                 placeholder="Paste contract text here, or upload a file above...\n\n💡 Scanned PDFs are automatically processed with OCR.",
                  lines=14,
                  max_lines=40,
                  show_copy_button=True,

              obligations_html = gr.HTML(label="Obligation Tracker")
          with gr.Tab("⚖️ Compliance"):
              compliance_html = gr.HTML(label="Compliance Checker")
+         with gr.Tab("✏️ Redlining"):
+             redlining_html = gr.HTML(label="Clause Redlining Suggestions")

          # ═══════ TAB 2: Contract Comparison ═══════
          with gr.Tab("🔀 Compare Contracts"):

              with gr.Column(scale=2):
                  comp_json = gr.JSON(label="Raw Comparison Data")
 
+         # ═══════ TAB 3: Contract Q&A Chatbot ═══════
+         with gr.Tab("💬 Contract Q&A"):
+             gr.HTML("""
+             <div style="padding:12px 16px;background:linear-gradient(135deg,#eff6ff,#faf5ff);border-radius:10px;margin-bottom:12px;border:1px solid #e5e7eb;">
+             <div style="display:flex;align-items:center;gap:8px;margin-bottom:6px;">
+             <span style="font-size:20px;">💬</span>
+             <h3 style="margin:0;font-size:16px;color:#1f2937;">Contract Q&A Chatbot</h3>
+             </div>
+             <p style="font-size:12px;color:#6b7280;margin:0;line-height:1.5;">
+             Ask questions about your analyzed contract. The chatbot uses <strong>RAG</strong> (Retrieval-Augmented Generation)
+             to find relevant clauses and generate accurate answers grounded in your contract text.
+             <br>
+             <strong>Step 1:</strong> Analyze a contract in the "📄 Single Contract Analysis" tab.
+             <strong>Step 2:</strong> Come here and ask questions!
+             </p>
+             </div>
+             """)
+
+             chatbot_index_status = gr.Textbox(
+                 label="📡 Chatbot Index Status",
+                 interactive=False,
+                 lines=1,
+                 value="⏳ No contract indexed yet — analyze a contract first",
+             )
+
+             def _chatbot_fn(message, history, chunks, embeddings, analysis):
+                 """Wrapper for ChatInterface fn signature."""
+                 yield from chat_respond(message, history, chunks, embeddings, analysis)
+
+             gr.ChatInterface(
+                 fn=_chatbot_fn,
+                 type="messages",
+                 additional_inputs=[chunks_state, embeddings_state, analysis_state],
+                 examples=[
+                     ["What are the main risks in this contract?"],
+                     ["Who are the parties involved?"],
+                     ["What happens if the contract is terminated?"],
+                     ["Are there any liability limitations?"],
+                     ["What are my obligations under this contract?"],
+                     ["Is there an arbitration clause?"],
+                     ["What is the governing law?"],
+                     ["Summarize the key terms in plain language."],
+                 ],
+                 title="",
+                 description="",
+             )
+
      # ── Events ──
      def _load_file(file):
          text, err = parse_document(file) if file else ("", "No file")
              return "", err
          return text, "Loaded successfully" if not err else err

+     def _analysis_and_index(text):
+         """Run analysis AND index for chatbot in one call."""
+         # Run the standard analysis
+         analysis_outputs = run_analysis(text)
+
+         # Index for chatbot (uses the raw text)
+         chunks, embeddings, index_status = index_contract(text)
+
+         # analysis_outputs has 12 items: 8 HTML + json_path + csv_path + status + result
+         # We need to add: chunks_state, embeddings_state, chatbot_index_status
+         return analysis_outputs + [chunks, embeddings, index_status]
+
      load_btn.click(_load_file, inputs=[file_input], outputs=[text_input, load_status])
      comp_load_a.click(_load_file, inputs=[comp_file_a], outputs=[comp_text_a, comp_status_a])
      comp_load_b.click(_load_file, inputs=[comp_file_b], outputs=[comp_text_b, comp_status_b])

      scan_btn.click(
+         _analysis_and_index,
          inputs=[text_input],
+         outputs=[
+             summary_html, clauses_html, entities_html, nli_html,
+             doc_html, obligations_html, compliance_html, redlining_html,
+             json_file, csv_file, status_msg, analysis_state,
+             chunks_state, embeddings_state, chatbot_index_status,
+         ]
      )

      clear_btn.click(
+         lambda: [""] * 8 + [None, None, "", None, [], None, "⏳ No contract indexed"],
+         outputs=[
+             summary_html, clauses_html, entities_html, nli_html,
+             doc_html, obligations_html, compliance_html, redlining_html,
+             json_file, csv_file, status_msg, analysis_state,
+             chunks_state, embeddings_state, chatbot_index_status,
+         ]
      )

      comp_btn.click(
 
  · Model: <a href="https://huggingface.co/Mokshith31/legalbert-contract-clause-classification" style="color:#6b7280;">Legal-BERT + CUAD (41 classes)</a>
  · NER: <a href="https://huggingface.co/matterstack/legal-bert-ner" style="color:#6b7280;">Legal-BERT NER</a>
  · NLI: <a href="https://huggingface.co/cross-encoder/nli-deberta-v3-base" style="color:#6b7280;">DeBERTa-v3 NLI</a>
+ · LLM: <a href="https://huggingface.co/Qwen/Qwen2.5-7B-Instruct" style="color:#6b7280;">Qwen2.5-7B</a>
+ · OCR: <a href="https://github.com/mindee/doctr" style="color:#6b7280;">docTR</a>
  · Dataset: <a href="https://huggingface.co/datasets/theatticusproject/cuad-qa" style="color:#6b7280;">CUAD</a>
  · <a href="https://huggingface.co/spaces/gaurv007/ClauseGuard" style="color:#6b7280;">ClauseGuard Space</a>
  </p>
chatbot.py ADDED
@@ -0,0 +1,406 @@
+ """
+ ClauseGuard — Contract Q&A Chatbot (RAG) v1.0
+ ═══════════════════════════════════════════════
+ Architecture:
+     User asks question about their contract
+     [1] Embed question with sentence-transformers (all-MiniLM-L6-v2)
+     [2] Retrieve top-5 most relevant chunks from contract
+     [3] Build prompt:
+         - System: ClauseGuard analysis results (clauses, entities, risk scores)
+         - Context: Retrieved contract chunks (≤2.5K tokens)
+         - User question
+     [4] Stream response from LLM via HF Inference API
+
+ Key design:
+ • Analyzed data (clauses, entities, risk scores) → system prompt
+ • Raw contract text → RAG retrieval
+ • This gives the model both structured analysis AND verbatim evidence
+ """
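Step [2] of the architecture above is a plain cosine-similarity top-k search over the precomputed chunk embeddings. A minimal sketch with numpy (toy 2-D vectors here; in the app the vectors come from all-MiniLM-L6-v2, and the function name is illustrative):

```python
import numpy as np

def top_k_chunks(query_vec, chunk_vecs, chunks, k=5):
    """Return the k chunks whose embeddings are most cosine-similar to the query."""
    q = query_vec / np.linalg.norm(query_vec)
    m = chunk_vecs / np.linalg.norm(chunk_vecs, axis=1, keepdims=True)
    scores = m @ q                      # cosine similarity per chunk
    order = np.argsort(scores)[::-1][:k]  # best-first
    return [chunks[i] for i in order]
```

For a few hundred chunks per contract, this brute-force matrix product is fast enough that no vector index is needed.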
+
+ import os
+ import re
+ import numpy as np
+
+ # ── Embedding model (soft-fail) ─────────────────────────────────────
+ _HAS_EMBEDDER = False
+ _embedder = None
+
+ try:
+     from sentence_transformers import SentenceTransformer
+     _HAS_EMBEDDER = True
+ except ImportError:
+     pass
+
+ # ── HF Inference Client (soft-fail) ─────────────────────────────────
+ _HAS_INFERENCE = False
+ _llm_client = None
+
+ try:
+     from huggingface_hub import InferenceClient
+     _HAS_INFERENCE = True
+ except ImportError:
+     pass
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # MODEL LOADING
+ # ═══════════════════════════════════════════════════════════════════════
+
+ _chatbot_status = {"embedder": "not_loaded", "llm": "not_loaded"}
+
+ def _load_embedder():
+     """Load sentence-transformers embedding model (lazy)."""
+     global _embedder, _chatbot_status
+     if _embedder is not None:
+         return _embedder
+     if not _HAS_EMBEDDER:
+         _chatbot_status["embedder"] = "unavailable"
+         return None
+     try:
+         print("[ClauseGuard Chat] Loading embedding model: all-MiniLM-L6-v2...")
+         _embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+         _chatbot_status["embedder"] = "loaded"
+         print("[ClauseGuard Chat] Embedding model loaded")
+         return _embedder
+     except Exception as e:
+         _chatbot_status["embedder"] = f"failed: {e}"
+         print(f"[ClauseGuard Chat] Embedder load failed: {e}")
+         return None
+
+ def _get_llm_client():
+     """Get or create HF Inference Client (lazy)."""
+     global _llm_client, _chatbot_status
+     if _llm_client is not None:
+         return _llm_client
+     if not _HAS_INFERENCE:
+         _chatbot_status["llm"] = "unavailable"
+         return None
+     try:
+         token = os.environ.get("HF_TOKEN", "")
+         _llm_client = InferenceClient(
+             provider="hf-inference",
+             api_key=token if token else None,
+         )
+         _chatbot_status["llm"] = "loaded"
+         print("[ClauseGuard Chat] HF Inference Client initialized")
+         return _llm_client
+     except Exception as e:
+         _chatbot_status["llm"] = f"failed: {e}"
+         print(f"[ClauseGuard Chat] LLM client init failed: {e}")
+         return None
+
+ def get_chatbot_status():
+     """Return human-readable chatbot status."""
+     parts = []
+     for name, status in _chatbot_status.items():
+         icon = "✅" if status == "loaded" else "⚠️" if "failed" in status else "❌"
+         label = {"embedder": "Embeddings", "llm": "LLM API"}[name]
+         parts.append(f"{icon} {label}: {status}")
+     return " · ".join(parts)
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # TEXT CHUNKING (sentence-preserving, ~300 tokens, no overlap)
+ # ═══════════════════════════════════════════════════════════════════════
+
+ def chunk_contract_text(text, target_chunk_size=300, min_chunk_size=50):
+     """
+     Split contract text into chunks for RAG retrieval.
+     Sentence-preserving, ~300 tokens per chunk, 0% overlap.
+     Research (arxiv 2601.14123): overlap adds cost with zero benefit.
+     """
+     if not text:
+         return []
+
+     # First split on paragraph boundaries
+     paragraphs = re.split(r'\n\n+', text)
+     chunks = []
+     current_chunk = ""
+
+     for para in paragraphs:
+         para = para.strip()
+         if not para:
+             continue
+
+         # Estimate word count (rough token proxy)
+         words_current = len(current_chunk.split())
+         words_para = len(para.split())
+
+         if words_current + words_para <= target_chunk_size:
+             current_chunk += ("\n\n" + para if current_chunk else para)
+         else:
+             # Current chunk is full enough — save it
+             if words_current >= min_chunk_size:
+                 chunks.append(current_chunk.strip())
+                 current_chunk = para
+             else:
+                 # Current chunk too small — need to split the paragraph into sentences
+                 sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', para)
+                 for sent in sentences:
+                     words_current = len(current_chunk.split())
146
+ words_sent = len(sent.split())
147
+ if words_current + words_sent <= target_chunk_size:
148
+ current_chunk += (" " + sent if current_chunk else sent)
149
+ else:
150
+ if words_current >= min_chunk_size:
151
+ chunks.append(current_chunk.strip())
152
+ current_chunk = sent
153
+
154
+ # Flush the final chunk (note: a tail shorter than min_chunk_size is dropped)
155
+ if current_chunk.strip() and len(current_chunk.split()) >= min_chunk_size:
156
+ chunks.append(current_chunk.strip())
157
+
158
+ return chunks
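Illustration (not part of the commit): the paragraph-level packing strategy of `chunk_contract_text` can be exercised standalone. `greedy_chunks` below is a hypothetical, simplified restatement that keeps only the paragraph path (it skips the sentence-splitting fallback):

```python
import re

def greedy_chunks(text, target=300, minimum=50):
    # Simplified sketch of chunk_contract_text's paragraph-level path:
    # pack whole paragraphs greedily until the ~300-word budget is hit.
    chunks, current = [], ""
    for para in (p.strip() for p in re.split(r'\n\n+', text) if p.strip()):
        if len(current.split()) + len(para.split()) <= target:
            current = f"{current}\n\n{para}" if current else para
        else:
            if len(current.split()) >= minimum:
                chunks.append(current)
            current = para
    if len(current.split()) >= minimum:
        chunks.append(current)  # flush the tail
    return chunks

doc = "\n\n".join("word " * 100 for _ in range(6))  # six 100-word paragraphs
out = greedy_chunks(doc)
print([len(c.split()) for c in out])  # → [300, 300]
```

With zero overlap, each paragraph lands in exactly one chunk; the real function additionally falls back to sentence boundaries when a paragraph would overshoot the budget.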
159
+
160
+
161
+ # ═══════════════════════════════════════════════════════════════════════
162
+ # EMBEDDING & RETRIEVAL
163
+ # ═══════════════════════════════════════════════════════════════════════
164
+
165
+ def build_embeddings(chunks):
166
+ """
167
+ Embed chunks using sentence-transformers.
168
+ Returns numpy array of shape (N, 384) or None if embedder unavailable.
169
+ """
170
+ embedder = _load_embedder()
171
+ if embedder is None or not chunks:
172
+ return None
173
+ try:
174
+ embeddings = embedder.encode(
175
+ chunks,
176
+ normalize_embeddings=True,
177
+ batch_size=32,
178
+ show_progress_bar=False,
179
+ )
180
+ return embeddings # numpy array (N, 384)
181
+ except Exception as e:
182
+ print(f"[ClauseGuard Chat] Embedding error: {e}")
183
+ return None
184
+
185
+
186
+ def retrieve_chunks(query, chunks, embeddings, top_k=5):
187
+ """
188
+ Retrieve top-k most relevant chunks for a query.
189
+ Uses cosine similarity (embeddings are L2-normalized → dot product = cosine).
190
+ Context budget: top-5 chunks, capped at ~600 words (≈800 tokens).
191
+ """
192
+ embedder = _load_embedder()
193
+ if embedder is None or embeddings is None or not chunks:
194
+ return []
195
+
196
+ try:
197
+ q_emb = embedder.encode([query], normalize_embeddings=True)
198
+ scores = (q_emb @ embeddings.T)[0]
199
+ top_indices = np.argsort(scores)[::-1][:top_k]
200
+
201
+ results = []
202
+ total_words = 0
203
+ max_words = 600 # ≈800-token context budget (~1.3 tokens/word)
204
+
205
+ for idx in top_indices:
206
+ chunk = chunks[idx]
207
+ chunk_words = len(chunk.split())
208
+ if total_words + chunk_words > max_words and results:
209
+ break
210
+ results.append({
211
+ "text": chunk,
212
+ "score": float(scores[idx]),
213
+ "index": int(idx),
214
+ })
215
+ total_words += chunk_words
216
+
217
+ return results
218
+ except Exception as e:
219
+ print(f"[ClauseGuard Chat] Retrieval error: {e}")
220
+ return []
221
+
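Illustration (not part of the commit): `retrieve_chunks` can rank with a single matmul because, for L2-normalized vectors, the dot product equals cosine similarity. A tiny deterministic sketch:

```python
import numpy as np

# Four unit-length "chunk" embeddings (3-dim for readability).
emb = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.6, 0.8, 0.0],
    [0.0, 0.0, 1.0],
])

q = np.array([0.8, 0.6, 0.0])   # unit-length query, closest to row 2
scores = q @ emb.T              # cosine scores ≈ [0.8, 0.6, 0.96, 0.0]
top_k = np.argsort(scores)[::-1][:2]
print(top_k.tolist())  # → [2, 0]
```

This mirrors `q_emb @ embeddings.T` above: no explicit normalization is needed at query time because `encode(..., normalize_embeddings=True)` already returns unit vectors.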
222
+
223
+ # ═══════════════════════════════════════════════════════════════════════
224
+ # SYSTEM PROMPT BUILDER
225
+ # ═══════════════════════════════════════════════════════════════════════
226
+
227
+ def _build_system_prompt(analysis_result, retrieved_chunks):
228
+ """
229
+ Build the system prompt with:
230
+ 1. ClauseGuard analysis results (clauses, entities, risk scores) — NOT through RAG
231
+ 2. Retrieved contract chunks — through RAG
232
+ """
233
+ parts = []
234
+
235
+ parts.append("""You are ClauseGuard AI, a legal contract analysis assistant. You help users understand their contracts by answering questions based on the contract text and analysis results.
236
+
237
+ RULES:
238
+ - Answer ONLY based on the provided contract text and analysis. Never make up information.
239
+ - If the answer isn't in the provided context, say "I don't see that information in the analyzed contract."
240
+ - Cite specific clauses or sections when possible.
241
+ - Be concise but thorough. Use plain language, not legal jargon.
242
+ - Always end with: "⚠️ This is AI analysis, not legal advice. Consult an attorney for legal decisions."
243
+ """)
244
+
245
+ # Add analysis summary if available
246
+ if analysis_result:
247
+ risk = analysis_result.get("risk", {})
248
+ parts.append(f"""
249
+ ═══ CONTRACT ANALYSIS SUMMARY ═══
250
+ Risk Score: {risk.get('score', 'N/A')}/100 (Grade {risk.get('grade', 'N/A')})
251
+ Risk Breakdown: {risk.get('breakdown', {})}
252
+ Total Clauses Analyzed: {analysis_result.get('metadata', {}).get('total_clauses', 'N/A')}
253
+ Flagged Clauses: {analysis_result.get('metadata', {}).get('flagged_clauses', 'N/A')}
254
+ """)
255
+
256
+ # Add detected clauses summary
257
+ clauses = analysis_result.get("clauses", [])
258
+ if clauses:
259
+ clause_summary = []
260
+ seen = set()
261
+ for c in clauses:
262
+ key = c["label"]
263
+ if key not in seen:
264
+ seen.add(key)
265
+ risk_level = c.get("risk", "LOW")
266
+ clause_summary.append(f" • [{risk_level}] {key}: {c.get('description', '')}")
267
+ parts.append("═══ DETECTED CLAUSES ═══\n" + "\n".join(clause_summary[:20]))
268
+
269
+ # Add entities summary
270
+ entities = analysis_result.get("entities", [])
271
+ if entities:
272
+ entity_summary = []
273
+ seen = set()
274
+ for e in entities:
275
+ key = f"{e['type']}: {e['text']}"
276
+ if key not in seen and len(seen) < 15:
277
+ seen.add(key)
278
+ entity_summary.append(f" • {e['type']}: {e['text']}")
279
+ parts.append("═══ EXTRACTED ENTITIES ═══\n" + "\n".join(entity_summary))
280
+
281
+ # Add contradictions
282
+ contradictions = analysis_result.get("contradictions", [])
283
+ if contradictions:
284
+ contra_summary = []
285
+ for c in contradictions:
286
+ contra_summary.append(f" • [{c['type']}] {c['explanation']}")
287
+ parts.append("═══ CONTRADICTIONS / ISSUES ═══\n" + "\n".join(contra_summary))
288
+
289
+ # Add retrieved contract text
290
+ if retrieved_chunks:
291
+ context_text = "\n---\n".join(c["text"] for c in retrieved_chunks)
292
+ parts.append(f"""
293
+ ═══ RELEVANT CONTRACT TEXT (Retrieved) ═══
294
+ {context_text}
295
+ """)
296
+
297
+ return "\n\n".join(parts)
298
+
299
+
300
+ # ═══════════════════════════════════════════════════════════════════════
301
+ # CHAT RESPONSE (Streaming)
302
+ # ═══════════════════════════════════════════════════════════════════════
303
+
304
+ # LLM model to use
305
+ _LLM_MODEL = "Qwen/Qwen2.5-7B-Instruct"
306
+
307
+ def chat_respond(message, history, chunks, embeddings, analysis_result):
308
+ """
309
+ RAG chatbot response function for gr.ChatInterface.
310
+
311
+ Args:
312
+ message: User's question (str)
313
+ history: Chat history (list of dicts with role/content)
314
+ chunks: Contract text chunks (list of str)
315
+ embeddings: Chunk embeddings (numpy array or None)
316
+ analysis_result: Full analysis result dict (or None)
317
+
318
+ Yields:
319
+ Partial response string (streaming)
320
+ """
321
+ # Validate inputs
322
+ if not chunks or embeddings is None:
323
+ yield ("⚠️ No contract loaded yet. Please upload and analyze a contract in the "
324
+ "**📄 Single Contract Analysis** tab first, then come back here to ask questions.")
325
+ return
326
+
327
+ if not message or not message.strip():
328
+ yield "Please ask a question about your contract."
329
+ return
330
+
331
+ # Step 1: Retrieve relevant chunks
332
+ retrieved = retrieve_chunks(message, chunks, embeddings, top_k=5)
333
+
334
+ # Step 2: Build system prompt with analysis + retrieved context
335
+ system_prompt = _build_system_prompt(analysis_result, retrieved)
336
+
337
+ # Step 3: Build message history for LLM
338
+ messages = [{"role": "system", "content": system_prompt}]
339
+
340
+ # Add recent history (last 6 messages, i.e. 3 user/assistant turns, to stay in the context window)
341
+ if history:
342
+ for h in history[-6:]:
343
+ messages.append({"role": h["role"], "content": h["content"]})
344
+
345
+ messages.append({"role": "user", "content": message})
346
+
347
+ # Step 4: Stream response from LLM
348
+ client = _get_llm_client()
349
+ if client is None:
350
+ yield ("⚠️ LLM service unavailable. Please ensure `huggingface_hub` is installed "
351
+ "and `HF_TOKEN` is set.")
352
+ return
353
+
354
+ try:
355
+ stream = client.chat_completion(
356
+ model=_LLM_MODEL,
357
+ messages=messages,
358
+ max_tokens=1024,
359
+ stream=True,
360
+ temperature=0.3, # Low temperature for factual responses
361
+ )
362
+ partial = ""
363
+ for chunk in stream:
364
+ token = chunk.choices[0].delta.content or ""
365
+ partial += token
366
+ yield partial
367
+ except Exception as e:
368
+ error_msg = str(e)
369
+ if "rate limit" in error_msg.lower() or "429" in error_msg:
370
+ yield ("⚠️ Rate limit reached on the free HF Inference API. "
371
+ "Please wait a moment and try again.")
372
+ elif "401" in error_msg or "unauthorized" in error_msg.lower():
373
+ yield ("⚠️ Authentication error. Please set your HF_TOKEN in the Space settings.")
374
+ else:
375
+ yield f"⚠️ Error generating response: {error_msg}\n\nPlease try again."
376
+
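Illustration (not part of the commit): `chat_respond` yields cumulative partials, which is the streaming contract `gr.ChatInterface` expects (the UI re-renders the latest string each step). A minimal stand-in (`fake_stream` is a hypothetical name) shows the consumption pattern:

```python
def fake_stream(tokens):
    # Yield the growing partial response, like chat_respond does
    # while iterating over client.chat_completion(stream=True).
    partial = ""
    for t in tokens:
        partial += t
        yield partial

final = None
for partial in fake_stream(["The ", "cap ", "is ", "$1M."]):
    final = partial  # a UI would re-render this string on every yield

print(final)  # → "The cap is $1M."
```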
377
+
378
+ # ═══════════════════════════════════════════════════════════════════════
379
+ # INDEXING HELPER (combines chunking + embedding)
380
+ # ═══════════════════════════════════════════════════════════════════════
381
+
382
+ def index_contract(text):
383
+ """
384
+ Chunk and embed contract text for RAG retrieval.
385
+
386
+ Returns: (chunks, embeddings, status_message)
387
+ chunks: list of str
388
+ embeddings: numpy array or None
389
+ status_message: str
390
+ """
391
+ if not text or len(text.strip()) < 50:
392
+ return [], None, "⚠️ No contract text to index"
393
+
394
+ chunks = chunk_contract_text(text)
395
+ if not chunks:
396
+ return [], None, "⚠️ Could not split contract into chunks"
397
+
398
+ embeddings = build_embeddings(chunks)
399
+ if embeddings is None:
400
+ return chunks, None, "⚠️ Embedding model unavailable — chatbot will not work"
401
+
402
+ return (
403
+ chunks,
404
+ embeddings,
405
+ f"✅ Indexed {len(chunks)} chunks ({len(text)} chars) — Ready to chat!"
406
+ )
compare.py CHANGED
@@ -98,6 +98,28 @@ def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
98
  if clauses_b is None:
99
  clauses_b = _split_clauses(text_b)
100
101
  # Build clause type maps
102
  type_map_a = defaultdict(list)
103
  type_map_b = defaultdict(list)
@@ -111,8 +133,9 @@ def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
111
  matched_b = set()
112
  modified = []
113
 
114
- SIMILARITY_THRESHOLD = 0.70
115
- MODIFIED_THRESHOLD = 0.40
 
116
 
117
  for i, ca in enumerate(clauses_a):
118
  best_sim = 0
@@ -181,12 +204,20 @@ def compare_contracts(text_a, text_b, clauses_a=None, clauses_b=None):
181
  risk_delta = "Similar risk profiles"
182
  risk_winner = "tie"
183
 
184
  comparison_method = "semantic (sentence embeddings)" if _embedder is not None else "lexical (string matching)"
185
 
186
  return {
187
  "alignment_score": round(alignment, 3),
188
  "contract_a_clauses": len(clauses_a),
189
  "contract_b_clauses": len(clauses_b),
190
  "added_clauses": [{"text": c[:200], "type": _extract_clause_type(c)} for c in added[:50]],
191
  "removed_clauses": [{"text": c[:200], "type": _extract_clause_type(c)} for c in removed[:50]],
192
  "modified_clauses": modified[:50],
 
98
  if clauses_b is None:
99
  clauses_b = _split_clauses(text_b)
100
 
101
+ # Fix 9: Detect contract types and flag cross-domain comparisons
102
+ _CONTRACT_TYPE_KEYWORDS = {
103
+ "employment": ["employee", "employer", "salary", "compensation", "benefits", "vacation", "severance", "at-will"],
104
+ "lease": ["landlord", "tenant", "rent", "premises", "lease", "occupancy", "security deposit", "eviction"],
105
+ "service": ["service provider", "customer", "SLA", "deliverables", "statement of work", "SOW"],
106
+ "nda": ["confidential", "non-disclosure", "disclosing party", "receiving party"],
107
+ "saas": ["subscription", "SaaS", "cloud", "uptime", "API", "data processing"],
108
+ "purchase": ["buyer", "seller", "purchase order", "goods", "shipment", "delivery"],
109
+ }
110
+
111
+ def _detect_contract_type(text):
112
+ text_lower = text.lower()
113
+ scores = {}
114
+ for ctype, keywords in _CONTRACT_TYPE_KEYWORDS.items():
115
+ scores[ctype] = sum(1 for kw in keywords if kw.lower() in text_lower)
116
+ best = max(scores, key=scores.get)
117
+ return best if scores[best] >= 2 else "general"
118
+
119
+ type_a = _detect_contract_type(text_a)
120
+ type_b = _detect_contract_type(text_b)
121
+ is_cross_domain = type_a != type_b and type_a != "general" and type_b != "general"
122
+
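Illustration (not part of the commit): the keyword-vote detector above, restated standalone with a trimmed keyword map (`detect` and `KEYWORDS` are illustrative names). The "general" fallback fires when fewer than 2 keywords match; note that plain substring matching can over-match (e.g. "rent" inside "current"), a limitation shared with the code above:

```python
def detect(text, keyword_map):
    # Count keyword hits per contract type; require >= 2 hits to commit.
    low = text.lower()
    scores = {t: sum(1 for kw in kws if kw in low)
              for t, kws in keyword_map.items()}
    best = max(scores, key=scores.get)
    return best if scores[best] >= 2 else "general"

KEYWORDS = {
    "employment": ["employee", "employer", "salary", "severance"],
    "lease": ["landlord", "tenant", "rent", "premises"],
}

print(detect("The Tenant shall pay Rent to the Landlord monthly.", KEYWORDS))  # → lease
print(detect("This Agreement is made between the parties.", KEYWORDS))         # → general
```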
123
  # Build clause type maps
124
  type_map_a = defaultdict(list)
125
  type_map_b = defaultdict(list)
 
133
  matched_b = set()
134
  modified = []
135
 
136
+ # Fix 10: Raise thresholds to reject false "modified" matches
137
+ SIMILARITY_THRESHOLD = 0.75 # was 0.70 — too many false matches
138
+ MODIFIED_THRESHOLD = 0.55 # was 0.40 — "Good Reason" ≠ "Force Majeure"
139
 
140
  for i, ca in enumerate(clauses_a):
141
  best_sim = 0
 
204
  risk_delta = "Similar risk profiles"
205
  risk_winner = "tie"
206
 
207
+ # Fix 9: Cross-domain warning
208
+ if is_cross_domain:
209
+ risk_delta = f"Cross-domain comparison ({type_a} vs {type_b}) — risk delta not meaningful across different contract types"
210
+ risk_winner = "cross-domain"
211
+
212
  comparison_method = "semantic (sentence embeddings)" if _embedder is not None else "lexical (string matching)"
213
 
214
  return {
215
  "alignment_score": round(alignment, 3),
216
  "contract_a_clauses": len(clauses_a),
217
  "contract_b_clauses": len(clauses_b),
218
+ "contract_a_type": type_a,
219
+ "contract_b_type": type_b,
220
+ "is_cross_domain": is_cross_domain,
221
  "added_clauses": [{"text": c[:200], "type": _extract_clause_type(c)} for c in added[:50]],
222
  "removed_clauses": [{"text": c[:200], "type": _extract_clause_type(c)} for c in removed[:50]],
223
  "modified_clauses": modified[:50],
ml/ClauseGuard_DeBERTa_Training.ipynb ADDED
@@ -0,0 +1,1041 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "# 🛡️ ClauseGuard v4 — DeBERTa-v3-large 2-Stage Training\n",
23
+ "\n",
24
+ "**Goal:** Train a production-grade contract clause classifier that replaces the current Legal-BERT-base (50% F1 → target 80-87% F1)\n",
25
+ "\n",
26
+ "## Architecture\n",
27
+ "| Setting | Value | Source |\n",
28
+ "|---------|-------|--------|\n",
29
+ "| Base model | `microsoft/deberta-v3-large` (435M params) | LexGLUE: outperforms Legal-BERT by 7-10pp |\n",
30
+ "| Max length | 512 tokens | MAUD paper: covers 72.4% of clauses without truncation |\n",
31
+ "| Loss function | Asymmetric Loss (γ-=4, clip=0.05) | ASL paper (2009.14119): +3-8pp on rare classes |\n",
32
+ "| Training | Full fine-tuning (no LoRA) | Full FT wins for encoder classification |\n",
33
+ "\n",
34
+ "## 2-Stage Training Pipeline\n",
35
+ "1. **Stage 1 — LEDGAR** (60K legal provisions, 100 classes): Teaches \"what types of contract clauses exist\"\n",
36
+ "2. **Stage 2 — CUAD** (41 CUAD classes): Target task with Asymmetric Loss for class imbalance\n",
37
+ "\n",
38
+ "**Runtime:** ~8-12 hours on T4 GPU (or ~4-6 hours on A100)\n",
39
+ "\n",
40
+ "**Before running:**\n",
41
+ "1. `Runtime` → `Change runtime type` → **T4 GPU**\n",
42
+ "2. `Runtime` → `Run all`\n",
43
+ "3. Paste your HuggingFace token when prompted"
44
+ ],
45
+ "metadata": {}
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "source": [
50
+ "## Step 1: Install Dependencies"
51
+ ],
52
+ "metadata": {}
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "source": [
57
+ "!pip install -q transformers datasets scikit-learn accelerate huggingface_hub torch\n",
58
+ "!pip install -q trackio # optional: experiment tracking"
59
+ ],
60
+ "metadata": {},
61
+ "execution_count": null,
62
+ "outputs": []
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "source": [
67
+ "## Step 2: Login to HuggingFace Hub"
68
+ ],
69
+ "metadata": {}
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "source": [
74
+ "from huggingface_hub import login\n",
75
+ "login()"
76
+ ],
77
+ "metadata": {},
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "## Step 3: Configuration"
85
+ ],
86
+ "metadata": {}
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "source": [
91
+ "import os\n",
92
+ "import torch\n",
93
+ "import numpy as np\n",
94
+ "\n",
95
+ "# ═══════════════════════════════════════════════════════════════\n",
96
+ "# CONFIGURATION — Edit these values\n",
97
+ "# ═══════════════════════════════════════════════════════════════\n",
98
+ "\n",
99
+ "BASE_MODEL = \"microsoft/deberta-v3-large\" # 435M params, MIT license\n",
100
+ "MAX_LENGTH = 512 # covers 72.4% of clauses\n",
101
+ "HUB_MODEL_ID = \"gaurv007/clauseguard-deberta-v3-large\" # ← your model repo\n",
102
+ "\n",
103
+ "# Stage 1: LEDGAR config\n",
104
+ "STAGE1_EPOCHS = 5 # LEDGAR is large, converges fast\n",
105
+ "STAGE1_LR = 2e-5\n",
106
+ "STAGE1_BATCH = 2 # T4 fp32: reduced for DeBERTa-v3 compatibility\n",
107
+ "STAGE1_GRAD_ACCUM = 16 # effective batch = 32 (2 * 16)\n",
108
+ "\n",
109
+ "# Stage 2: CUAD config \n",
110
+ "STAGE2_EPOCHS = 20\n",
111
+ "STAGE2_LR = 1e-5 # lower LR for fine-tuning pretrained model\n",
112
+ "STAGE2_BATCH = 2 # T4 fp32: reduced for DeBERTa-v3 compatibility\n",
113
+ "STAGE2_GRAD_ACCUM = 16 # effective batch = 32 (2 * 16)\n",
114
+ "EARLY_STOPPING_PATIENCE = 3\n",
115
+ "\n",
116
+ "# ASL hyperparameters (from arxiv 2009.14119)\n",
117
+ "ASL_GAMMA_POS = 0\n",
118
+ "ASL_GAMMA_NEG = 4\n",
119
+ "ASL_CLIP = 0.05\n",
120
+ "\n",
121
+ "# Weight decay (DeBERTa default)\n",
122
+ "WEIGHT_DECAY = 0.06\n",
123
+ "WARMUP_RATIO = 0.1\n",
124
+ "\n",
125
+ "SEED = 42\n",
126
+ "\n",
127
+ "# ═══════════════════════════════════════════════════════════════\n",
128
+ "\n",
129
+ "# CUAD 41 label names (must match class_id 0-40 in CUAD dataset)\n",
130
+ "CUAD_LABELS = [\n",
131
+ " \"Document Name\", # 0\n",
132
+ " \"Parties\", # 1\n",
133
+ " \"Agreement Date\", # 2\n",
134
+ " \"Effective Date\", # 3\n",
135
+ " \"Expiration Date\", # 4\n",
136
+ " \"Renewal Term\", # 5\n",
137
+ " \"Notice Period to Terminate Renewal\", # 6\n",
138
+ " \"Governing Law\", # 7\n",
139
+ " \"Most Favored Nation\", # 8\n",
140
+ " \"Non-Compete\", # 9\n",
141
+ " \"Exclusivity\", # 10\n",
142
+ " \"No-Solicit of Customers\", # 11\n",
143
+ " \"No-Solicit of Employees\", # 12\n",
144
+ " \"Non-Disparagement\", # 13\n",
145
+ " \"Termination for Convenience\", # 14\n",
146
+ " \"ROFR/ROFO/ROFN\", # 15\n",
147
+ " \"Change of Control\", # 16\n",
148
+ " \"Anti-Assignment\", # 17\n",
149
+ " \"Revenue/Profit Sharing\", # 18\n",
150
+ " \"Price Restriction\", # 19\n",
151
+ " \"Minimum Commitment\", # 20\n",
152
+ " \"Volume Restriction\", # 21\n",
153
+ " \"IP Ownership Assignment\", # 22\n",
154
+ " \"Joint IP Ownership\", # 23\n",
155
+ " \"License Grant\", # 24\n",
156
+ " \"Non-Transferable License\", # 25\n",
157
+ " \"Affiliate License-Licensor\", # 26\n",
158
+ " \"Affiliate License-Licensee\", # 27\n",
159
+ " \"Unlimited/All-You-Can-Eat License\", # 28\n",
160
+ " \"Irrevocable or Perpetual License\", # 29\n",
161
+ " \"Source Code Escrow\", # 30\n",
162
+ " \"Post-Termination Services\", # 31\n",
163
+ " \"Audit Rights\", # 32\n",
164
+ " \"Uncapped Liability\", # 33\n",
165
+ " \"Cap on Liability\", # 34\n",
166
+ " \"Liquidated Damages\", # 35\n",
167
+ " \"Warranty Duration\", # 36\n",
168
+ " \"Insurance\", # 37\n",
169
+ " \"Covenant Not to Sue\", # 38\n",
170
+ " \"Third Party Beneficiary\", # 39\n",
171
+ " \"Other\", # 40\n",
172
+ "]\n",
173
+ "\n",
174
+ "NUM_CUAD_LABELS = len(CUAD_LABELS) # 41\n",
175
+ "\n",
176
+ "print(f\"🛡️ ClauseGuard v4 Training Configuration\")\n",
177
+ "print(f\" Base model: {BASE_MODEL}\")\n",
178
+ "print(f\" Max length: {MAX_LENGTH}\")\n",
179
+ "print(f\" Hub model: {HUB_MODEL_ID}\")\n",
180
+ "print(f\" GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
181
+ "print(f\" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\" if torch.cuda.is_available() else \"\")\n",
182
+ "print(f\" CUAD classes: {NUM_CUAD_LABELS}\")"
183
+ ],
184
+ "metadata": {},
185
+ "execution_count": null,
186
+ "outputs": []
187
+ },
188
+ {
189
+ "cell_type": "markdown",
190
+ "source": [
191
+ "## Step 4: Load Datasets"
192
+ ],
193
+ "metadata": {}
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "source": [
198
+ "from datasets import load_dataset, Dataset\n",
199
+ "import pandas as pd\n",
200
+ "from collections import Counter\n",
201
+ "\n",
202
+ "# ═══════════════════════════════════════════════════════════════\n",
203
+ "# Stage 1: LEDGAR (100 classes, single-label)\n",
204
+ "# ═══════════════════════════════════════════════════════════════\n",
205
+ "print(\"📚 Loading LEDGAR dataset...\")\n",
206
+ "ledgar = load_dataset(\"coastalcph/lex_glue\", \"ledgar\")\n",
207
+ "print(f\" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,} | Test: {len(ledgar['test']):,}\")\n",
208
+ "num_ledgar_labels = ledgar['train'].features['label'].num_classes\n",
209
+ "print(f\" Classes: {num_ledgar_labels}\")\n",
210
+ "\n",
211
+ "# ═══════════════════════════════════════════════════════════════\n",
212
+ "# Stage 2: CUAD (41 classes — reformulated for classification)\n",
213
+ "# ═══════════════════════════════════════════════════════════════\n",
214
+ "print(\"\\n📚 Loading CUAD classification dataset...\")\n",
215
+ "cuad_raw = load_dataset(\"dvgodoy/CUAD_v1_Contract_Understanding_clause_classification\", split=\"train\")\n",
216
+ "print(f\" Total rows: {len(cuad_raw):,}\")\n",
217
+ "\n",
218
+ "# Analyze class distribution\n",
219
+ "class_counts = Counter(cuad_raw['class_id'])\n",
220
+ "print(f\" Unique classes: {len(class_counts)}\")\n",
221
+ "print(f\" \\n Class distribution:\")\n",
222
+ "for cid in sorted(class_counts.keys()):\n",
223
+ " label_name = CUAD_LABELS[cid] if cid < len(CUAD_LABELS) else f\"Unknown-{cid}\"\n",
224
+ " count = class_counts[cid]\n",
225
+ " bar = '█' * min(50, count // 10)\n",
226
+ " print(f\" {cid:2d} {label_name:40s} {count:5d} {bar}\")"
227
+ ],
228
+ "metadata": {},
229
+ "execution_count": null,
230
+ "outputs": []
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "source": [
235
+ "## Step 5: Prepare CUAD Train/Val/Test Splits"
236
+ ],
237
+ "metadata": {}
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "source": [
242
+ "from sklearn.model_selection import train_test_split\n",
243
+ "\n",
244
+ "# CUAD only has train split — create val/test by splitting by file_name\n",
245
+ "# (so no data leakage between contracts)\n",
246
+ "cuad_df = cuad_raw.to_pandas()\n",
247
+ "\n",
248
+ "# Get unique file names\n",
249
+ "unique_files = cuad_df['file_name'].unique()\n",
250
+ "print(f\"Unique contracts: {len(unique_files)}\")\n",
251
+ "\n",
252
+ "# Split files 80/10/10\n",
253
+ "train_files, test_files = train_test_split(unique_files, test_size=0.2, random_state=SEED)\n",
254
+ "val_files, test_files = train_test_split(test_files, test_size=0.5, random_state=SEED)\n",
255
+ "\n",
256
+ "cuad_train_df = cuad_df[cuad_df['file_name'].isin(train_files)]\n",
257
+ "cuad_val_df = cuad_df[cuad_df['file_name'].isin(val_files)]\n",
258
+ "cuad_test_df = cuad_df[cuad_df['file_name'].isin(test_files)]\n",
259
+ "\n",
260
+ "print(f\"CUAD splits — Train: {len(cuad_train_df)} | Val: {len(cuad_val_df)} | Test: {len(cuad_test_df)}\")\n",
261
+ "print(f\"Train contracts: {len(train_files)} | Val contracts: {len(val_files)} | Test contracts: {len(test_files)}\")\n",
262
+ "\n",
263
+ "# Convert to HF Dataset\n",
264
+ "cuad_train = Dataset.from_pandas(cuad_train_df.reset_index(drop=True))\n",
265
+ "cuad_val = Dataset.from_pandas(cuad_val_df.reset_index(drop=True))\n",
266
+ "cuad_test = Dataset.from_pandas(cuad_test_df.reset_index(drop=True))\n",
267
+ "\n",
268
+ "# Verify class distribution in each split\n",
269
+ "for name, ds in [(\"Train\", cuad_train), (\"Val\", cuad_val), (\"Test\", cuad_test)]:\n",
270
+ " counts = Counter(ds['class_id'])\n",
271
+ " empty_classes = [i for i in range(NUM_CUAD_LABELS) if counts.get(i, 0) == 0]\n",
272
+ " print(f\" {name}: {len(ds)} rows, {len(counts)} classes present, {len(empty_classes)} classes missing: {empty_classes[:5]}...\")"
273
+ ],
274
+ "metadata": {},
275
+ "execution_count": null,
276
+ "outputs": []
277
+ },
278
+ {
279
+ "cell_type": "markdown",
280
+ "source": [
281
+ "## Step 6: Tokenizer & Preprocessing"
282
+ ],
283
+ "metadata": {}
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "source": [
288
+ "from transformers import AutoTokenizer\n",
289
+ "\n",
290
+ "print(f\"Loading tokenizer: {BASE_MODEL}\")\n",
291
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
292
+ "\n",
293
+ "# ── LEDGAR preprocessing (single-label) ──\n",
294
+ "def preprocess_ledgar(examples):\n",
295
+ " tokenized = tokenizer(\n",
296
+ " examples[\"text\"],\n",
297
+ " truncation=True,\n",
298
+ " max_length=MAX_LENGTH,\n",
299
+ " padding=False,\n",
300
+ " )\n",
301
+ " tokenized[\"labels\"] = examples[\"label\"] # int label for CrossEntropy\n",
302
+ " return tokenized\n",
303
+ "\n",
304
+ "# ── CUAD preprocessing (single-label per clause, 41 classes) ──\n",
305
+ "def preprocess_cuad(examples):\n",
306
+ " tokenized = tokenizer(\n",
307
+ " examples[\"clause\"],\n",
308
+ " truncation=True,\n",
309
+ " max_length=MAX_LENGTH,\n",
310
+ " padding=False,\n",
311
+ " )\n",
312
+ " tokenized[\"labels\"] = examples[\"class_id\"] # int label for CrossEntropy + ASL\n",
313
+ " return tokenized\n",
314
+ "\n",
315
+ "print(\"Tokenizing LEDGAR...\")\n",
316
+ "ledgar_tokenized = ledgar.map(\n",
317
+ " preprocess_ledgar, batched=True,\n",
318
+ " remove_columns=ledgar[\"train\"].column_names,\n",
319
+ " desc=\"Tokenizing LEDGAR\"\n",
320
+ ")\n",
321
+ "\n",
322
+ "print(\"Tokenizing CUAD...\")\n",
323
+ "cuad_train_tok = cuad_train.map(\n",
324
+ " preprocess_cuad, batched=True,\n",
325
+ " remove_columns=cuad_train.column_names,\n",
326
+ " desc=\"Tokenizing CUAD train\"\n",
327
+ ")\n",
328
+ "cuad_val_tok = cuad_val.map(\n",
329
+ " preprocess_cuad, batched=True,\n",
330
+ " remove_columns=cuad_val.column_names,\n",
331
+ " desc=\"Tokenizing CUAD val\"\n",
332
+ ")\n",
333
+ "cuad_test_tok = cuad_test.map(\n",
334
+ " preprocess_cuad, batched=True,\n",
335
+ " remove_columns=cuad_test.column_names,\n",
336
+ " desc=\"Tokenizing CUAD test\"\n",
337
+ ")\n",
338
+ "\n",
339
+ "# Check token lengths\n",
340
+ "train_lengths = [len(x) for x in cuad_train_tok['input_ids']]\n",
341
+ "print(f\"\\n📊 CUAD token length stats:\")\n",
342
+ "print(f\" Mean: {np.mean(train_lengths):.0f} | Median: {np.median(train_lengths):.0f}\")\n",
343
+ "print(f\" 95th pct: {np.percentile(train_lengths, 95):.0f} | Max: {max(train_lengths)}\")\n",
344
+ "print(f\" Truncated (>512): {sum(1 for l in train_lengths if l >= MAX_LENGTH)} ({sum(1 for l in train_lengths if l >= MAX_LENGTH)/len(train_lengths)*100:.1f}%)\")\n",
345
+ "print(\"✅ Tokenization complete!\")"
346
+ ],
347
+ "metadata": {},
348
+ "execution_count": null,
349
+ "outputs": []
350
+ },
351
+ {
352
+ "cell_type": "markdown",
353
+ "source": [
354
+ "## Step 7: Asymmetric Loss Function\n",
355
+ "\n",
356
+ "From [Asymmetric Loss For Multi-Label Classification](https://arxiv.org/abs/2009.14119) (ICCV 2021).\n",
357
+ "\n",
358
+ "Key idea: Down-weight easy negatives more aggressively than positives. Critical for CUAD where most labels are negative for any given clause."
359
+ ],
360
+ "metadata": {}
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "source": [
365
+ "import torch\n",
366
+ "import torch.nn as nn\n",
367
+ "import torch.nn.functional as F\n",
368
+ "\n",
369
+ "\n",
370
+ "class AsymmetricLoss(nn.Module):\n",
371
+ " \"\"\"\n",
372
+ " Asymmetric Loss from arxiv:2009.14119.\n",
373
+ " \n",
374
+ " For multi-class (single-label) classification with class imbalance:\n",
375
+     "    We use the multi-class variant: focal-style re-weighting of\n",
376
+     "    cross-entropy, modulated by a single (1 - p_t)^γ- factor.\n",
377
+ " \n",
378
+ " For multi-label (multi-hot) classification:\n",
379
+ " L+ = (1-p)^γ+ * log(p)\n",
380
+     "    L- = (pm)^γ- * log(1-pm),  pm = max(p - clip, 0)\n",
381
+ " \"\"\"\n",
382
+ " def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,\n",
383
+ " num_classes=None, class_weights=None, mode=\"multi_class\"):\n",
384
+ " super().__init__()\n",
385
+ " self.gamma_pos = gamma_pos\n",
386
+ " self.gamma_neg = gamma_neg\n",
387
+ " self.clip = clip\n",
388
+ " self.eps = eps\n",
389
+ " self.mode = mode\n",
390
+ " \n",
391
+ " # Optional class weights for severe imbalance\n",
392
+ " if class_weights is not None:\n",
393
+ " self.register_buffer('class_weights', torch.tensor(class_weights, dtype=torch.float32))\n",
394
+ " else:\n",
395
+ " self.class_weights = None\n",
396
+ "\n",
397
+ " def forward(self, logits, targets):\n",
398
+ " if self.mode == \"multi_label\":\n",
399
+ " return self._multi_label_loss(logits, targets)\n",
400
+ " else:\n",
401
+ " return self._multi_class_loss(logits, targets)\n",
402
+ " \n",
403
+ " def _multi_class_loss(self, logits, targets):\n",
404
+ " \"\"\"Focal-style cross-entropy with asymmetric gamma for single-label classification.\"\"\"\n",
405
+ " # Standard cross-entropy with class weights\n",
406
+ " if self.class_weights is not None:\n",
407
+ " ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')\n",
408
+ " else:\n",
409
+ " ce_loss = F.cross_entropy(logits, targets, reduction='none')\n",
410
+ " \n",
411
+ " # Apply focal modulation\n",
412
+ " probs = F.softmax(logits, dim=-1)\n",
413
+ " # Get probability of the correct class\n",
414
+ " p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)\n",
415
+ " \n",
416
+ " # Focal weight: (1 - p_t)^gamma\n",
417
+         "        # A single gamma_neg is applied to all examples (gamma_pos is unused in this variant)\n",
418
+ " focal_weight = (1 - p_t) ** self.gamma_neg\n",
419
+ " \n",
420
+ " loss = focal_weight * ce_loss\n",
421
+ " return loss.mean()\n",
422
+ "\n",
423
+ " def _multi_label_loss(self, logits, targets):\n",
424
+ " \"\"\"Full ASL for multi-label classification.\"\"\"\n",
425
+ " p = torch.sigmoid(logits)\n",
426
+ " \n",
427
+ " if self.clip is not None and self.clip > 0:\n",
428
+ " p_m = torch.clamp(p - self.clip, min=0)\n",
429
+ " else:\n",
430
+ " p_m = p\n",
431
+ " \n",
432
+ " loss_pos = targets * (1 - p) ** self.gamma_pos * torch.log(p + self.eps)\n",
433
+ " loss_neg = (1 - targets) * p_m ** self.gamma_neg * torch.log(1 - p_m + self.eps)\n",
434
+ " \n",
435
+ " loss = -(loss_pos + loss_neg)\n",
436
+ " return loss.mean()\n",
437
+ "\n",
438
+ "\n",
439
+ "print(\"✅ AsymmetricLoss defined\")\n",
440
+ "print(f\" γ+ = {ASL_GAMMA_POS}, γ- = {ASL_GAMMA_NEG}, clip = {ASL_CLIP}\")"
441
+ ],
442
+ "metadata": {},
443
+ "execution_count": null,
444
+ "outputs": []
445
+ },
446
+ {
447
+ "cell_type": "markdown",
448
+ "source": [
449
+ "## Step 8: Custom Trainer with ASL"
450
+ ],
451
+ "metadata": {}
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "source": [
456
+ "from transformers import Trainer\n",
457
+ "\n",
458
+ "\n",
459
+ "class ASLTrainer(Trainer):\n",
460
+ " \"\"\"Custom Trainer that uses Asymmetric Loss instead of standard CrossEntropy.\"\"\"\n",
461
+ " \n",
462
+ " def __init__(self, *args, asl_loss_fn=None, **kwargs):\n",
463
+ " super().__init__(*args, **kwargs)\n",
464
+ " self.asl = asl_loss_fn\n",
465
+ "\n",
466
+ " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
467
+ " labels = inputs.pop(\"labels\")\n",
468
+ " outputs = model(**inputs)\n",
469
+ " logits = outputs.logits\n",
470
+ " \n",
471
+ " if self.asl is not None:\n",
472
+ " loss = self.asl(logits, labels)\n",
473
+ " else:\n",
474
+ " # Fallback to standard cross-entropy\n",
475
+ " loss = F.cross_entropy(logits, labels)\n",
476
+ " \n",
477
+ " return (loss, outputs) if return_outputs else loss\n",
478
+ "\n",
479
+ "\n",
480
+ "print(\"✅ ASLTrainer defined\")"
481
+ ],
482
+ "metadata": {},
483
+ "execution_count": null,
484
+ "outputs": []
485
+ },
486
+ {
487
+ "cell_type": "markdown",
488
+ "source": [
489
+ "## Step 9: Metrics"
490
+ ],
491
+ "metadata": {}
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "source": [
496
+ "from sklearn.metrics import f1_score, precision_score, recall_score, classification_report\n",
497
+ "\n",
498
+ "\n",
499
+ "def compute_metrics_single_label(eval_pred):\n",
500
+ " \"\"\"Metrics for single-label classification (LEDGAR & CUAD).\"\"\"\n",
501
+ " logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
502
+ " preds = np.argmax(logits, axis=-1)\n",
503
+ " \n",
504
+ " micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
505
+ " macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
506
+ " weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n",
507
+ " accuracy = (preds == labels).mean()\n",
508
+ " \n",
509
+ " return {\n",
510
+ " \"accuracy\": accuracy,\n",
511
+ " \"micro_f1\": micro_f1,\n",
512
+ " \"macro_f1\": macro_f1,\n",
513
+ " \"weighted_f1\": weighted_f1,\n",
514
+ " }\n",
515
+ "\n",
516
+ "\n",
517
+ "def compute_metrics_cuad_detailed(eval_pred):\n",
518
+ " \"\"\"Detailed metrics for CUAD — includes per-class F1.\"\"\"\n",
519
+ " logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
520
+ " preds = np.argmax(logits, axis=-1)\n",
521
+ " \n",
522
+ " micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
523
+ " macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
524
+ " weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n",
525
+ " accuracy = (preds == labels).mean()\n",
526
+ " \n",
527
+ " # Per-class F1\n",
528
+ " per_class_f1 = f1_score(labels, preds, average=None, zero_division=0)\n",
529
+ " class_metrics = {}\n",
530
+ " for i, f1_val in enumerate(per_class_f1):\n",
531
+ " if i < len(CUAD_LABELS):\n",
532
+ " # Truncate label name for cleaner logging\n",
533
+ " safe_name = CUAD_LABELS[i][:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n",
534
+ " class_metrics[f\"f1_{safe_name}\"] = float(f1_val)\n",
535
+ " \n",
536
+ " return {\n",
537
+ " \"accuracy\": accuracy,\n",
538
+ " \"micro_f1\": micro_f1,\n",
539
+ " \"macro_f1\": macro_f1,\n",
540
+ " \"weighted_f1\": weighted_f1,\n",
541
+ " **class_metrics,\n",
542
+ " }\n",
543
+ "\n",
544
+ "\n",
545
+ "print(\"✅ Metrics functions defined\")"
546
+ ],
547
+ "metadata": {},
548
+ "execution_count": null,
549
+ "outputs": []
550
+ },
551
+ {
552
+ "cell_type": "markdown",
553
+ "source": [
554
+ "---\n",
555
+ "# 🏋️ STAGE 1: Pre-fine-tune on LEDGAR\n",
556
+ "\n",
557
+ "**Goal:** Teach DeBERTa-v3-large what types of contract clauses exist (100 classes, ~60K examples).\n",
558
+ "\n",
559
+ "This stage uses standard cross-entropy loss since LEDGAR is well-balanced.\n",
560
+ "\n",
561
+ "**Expected:** ~85-90% micro-F1 after 3-5 epochs (~3-5 hours on T4, ~1-2 hours on A100)"
562
+ ],
563
+ "metadata": {}
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "source": [
568
+ "from transformers import (\n",
569
+ " AutoConfig,\n",
570
+ " AutoModelForSequenceClassification,\n",
571
+ " TrainingArguments,\n",
572
+ " DataCollatorWithPadding,\n",
573
+ " EarlyStoppingCallback,\n",
574
+ ")\n",
575
+ "\n",
576
+ "print(f\"🏋️ STAGE 1: Pre-fine-tune on LEDGAR ({num_ledgar_labels} classes)\")\n",
577
+ "print(f\" Loading {BASE_MODEL}...\")\n",
578
+ "\n",
579
+ "# Load model for Stage 1 (100 classes, single-label)\n",
580
+ "stage1_model = AutoModelForSequenceClassification.from_pretrained(\n",
581
+ " BASE_MODEL,\n",
582
+ " num_labels=num_ledgar_labels,\n",
583
+ " problem_type=\"single_label_classification\",\n",
584
+ " ignore_mismatched_sizes=True,\n",
585
+ ")\n",
586
+ "\n",
587
+ "total_params = sum(p.numel() for p in stage1_model.parameters())\n",
588
+ "trainable_params = sum(p.numel() for p in stage1_model.parameters() if p.requires_grad)\n",
589
+ "print(f\" Total parameters: {total_params:,}\")\n",
590
+ "print(f\" Trainable parameters: {trainable_params:,}\")\n",
591
+ "\n",
592
+ "stage1_args = TrainingArguments(\n",
593
+ " output_dir=\"./stage1_ledgar\",\n",
594
+ " num_train_epochs=STAGE1_EPOCHS,\n",
595
+ " per_device_train_batch_size=STAGE1_BATCH,\n",
596
+ " per_device_eval_batch_size=4,\n",
597
+ " gradient_accumulation_steps=STAGE1_GRAD_ACCUM,\n",
598
+ " learning_rate=STAGE1_LR,\n",
599
+ " weight_decay=WEIGHT_DECAY,\n",
600
+ " warmup_ratio=WARMUP_RATIO,\n",
601
+ " lr_scheduler_type=\"cosine\",\n",
602
+ " eval_strategy=\"epoch\",\n",
603
+ " save_strategy=\"epoch\",\n",
604
+ " save_total_limit=2,\n",
605
+ " load_best_model_at_end=True,\n",
606
+ " metric_for_best_model=\"macro_f1\",\n",
607
+ " greater_is_better=True,\n",
608
+ " bf16=False, # DeBERTa-v3 breaks with fp16 gradient scaler; fp32 is safest on T4\n",
609
+ " fp16=False,\n",
610
+ " logging_strategy=\"steps\",\n",
611
+ " logging_steps=50,\n",
612
+ " logging_first_step=True,\n",
613
+ " disable_tqdm=False,\n",
614
+ " report_to=\"none\",\n",
615
+ " dataloader_num_workers=2,\n",
616
+ " seed=SEED,\n",
617
+ " gradient_checkpointing=True, # Critical for T4 (16GB VRAM)\n",
618
+ ")\n",
619
+ "\n",
620
+ "stage1_trainer = Trainer(\n",
621
+ " model=stage1_model,\n",
622
+ " args=stage1_args,\n",
623
+ " train_dataset=ledgar_tokenized[\"train\"],\n",
624
+ " eval_dataset=ledgar_tokenized[\"validation\"],\n",
625
+ " processing_class=tokenizer,\n",
626
+ " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
627
+ " compute_metrics=compute_metrics_single_label,\n",
628
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n",
629
+ ")\n",
630
+ "\n",
631
+ "print(\"\\n🚀 Starting Stage 1 training...\")\n",
632
+ "stage1_result = stage1_trainer.train()\n",
633
+ "print(f\"\\n✅ Stage 1 complete! Loss: {stage1_result.training_loss:.4f}\")"
634
+ ],
635
+ "metadata": {},
636
+ "execution_count": null,
637
+ "outputs": []
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "source": [
642
+ "# Evaluate Stage 1 on LEDGAR test set\n",
643
+ "print(\"📊 Stage 1 — LEDGAR Test Evaluation\")\n",
644
+ "stage1_test = stage1_trainer.evaluate(ledgar_tokenized[\"test\"])\n",
645
+ "print(f\" Accuracy: {stage1_test['eval_accuracy']:.4f}\")\n",
646
+ "print(f\" Micro-F1: {stage1_test['eval_micro_f1']:.4f}\")\n",
647
+ "print(f\" Macro-F1: {stage1_test['eval_macro_f1']:.4f}\")\n",
648
+ "print(f\" Weighted-F1: {stage1_test['eval_weighted_f1']:.4f}\")\n",
649
+ "\n",
650
+ "# Save Stage 1 checkpoint\n",
651
+ "STAGE1_CHECKPOINT = \"./stage1_ledgar_best\"\n",
652
+ "stage1_trainer.save_model(STAGE1_CHECKPOINT)\n",
653
+ "tokenizer.save_pretrained(STAGE1_CHECKPOINT)\n",
654
+ "print(f\"\\n💾 Stage 1 checkpoint saved to {STAGE1_CHECKPOINT}\")"
655
+ ],
656
+ "metadata": {},
657
+ "execution_count": null,
658
+ "outputs": []
659
+ },
660
+ {
661
+ "cell_type": "markdown",
662
+ "source": [
663
+ "---\n",
664
+ "# 🏋️ STAGE 2: Fine-tune on CUAD 41-class with Asymmetric Loss\n",
665
+ "\n",
666
+ "**Goal:** Learn the 41 CUAD contract clause types from the Stage 1 backbone.\n",
667
+ "\n",
668
+ "Key improvements over current ClauseGuard:\n",
669
+ "- DeBERTa-v3-large backbone pre-trained on LEDGAR (Stage 1)\n",
670
+ "- 512 tokens (vs 256) — captures full clause content\n",
671
+ "- Asymmetric Loss for class imbalance\n",
672
+ "- Full fine-tuning (no LoRA bottleneck)\n",
673
+ "\n",
674
+ "**Expected:** 75-87% macro-F1 after 10-20 epochs (~5-8 hours on T4, ~2-4 hours on A100)"
675
+ ],
676
+ "metadata": {}
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "source": [
681
+ "# Free Stage 1 model memory before loading Stage 2\n",
682
+ "del stage1_model, stage1_trainer\n",
683
+ "torch.cuda.empty_cache()\n",
684
+ "import gc; gc.collect()\n",
685
+ "\n",
686
+ "print(f\"🏋️ STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL\")\n",
687
+ "\n",
688
+ "# Load Stage 1 checkpoint with new head (100 → 41 classes)\n",
689
+ "stage2_model = AutoModelForSequenceClassification.from_pretrained(\n",
690
+ " STAGE1_CHECKPOINT,\n",
691
+ " num_labels=NUM_CUAD_LABELS,\n",
692
+ " ignore_mismatched_sizes=True, # classifier head: 100 → 41\n",
693
+ " problem_type=\"single_label_classification\",\n",
694
+ ")\n",
695
+ "\n",
696
+ "print(f\" Loaded Stage 1 backbone with new {NUM_CUAD_LABELS}-class head\")\n",
697
+ "print(f\" Parameters: {sum(p.numel() for p in stage2_model.parameters()):,}\")\n",
698
+ "\n",
699
+ "# Compute class weights from training distribution\n",
700
+ "train_class_counts = Counter(cuad_train_tok['labels'])\n",
701
+ "total_samples = sum(train_class_counts.values())\n",
702
+ "class_weights = []\n",
703
+ "for i in range(NUM_CUAD_LABELS):\n",
704
+ " count = train_class_counts.get(i, 1) # avoid div by zero\n",
705
+ " # Inverse frequency weighting, capped\n",
706
+ " weight = min(10.0, total_samples / (NUM_CUAD_LABELS * count))\n",
707
+ " class_weights.append(weight)\n",
708
+ "\n",
709
+ "print(f\" Class weight range: [{min(class_weights):.2f}, {max(class_weights):.2f}]\")\n",
710
+ "\n",
711
+ "# Create ASL loss\n",
712
+ "asl_loss = AsymmetricLoss(\n",
713
+ " gamma_pos=ASL_GAMMA_POS,\n",
714
+ " gamma_neg=ASL_GAMMA_NEG,\n",
715
+ " clip=ASL_CLIP,\n",
716
+ " num_classes=NUM_CUAD_LABELS,\n",
717
+ " class_weights=class_weights,\n",
718
+ " mode=\"multi_class\", # single-label per clause\n",
719
+ ")\n",
720
+ "# Move to GPU\n",
721
+ "if torch.cuda.is_available():\n",
722
+ " asl_loss = asl_loss.cuda()\n",
723
+ "\n",
724
+ "stage2_args = TrainingArguments(\n",
725
+ " output_dir=\"./stage2_cuad\",\n",
726
+ " num_train_epochs=STAGE2_EPOCHS,\n",
727
+ " per_device_train_batch_size=STAGE2_BATCH,\n",
728
+ " per_device_eval_batch_size=4,\n",
729
+ " gradient_accumulation_steps=STAGE2_GRAD_ACCUM,\n",
730
+ " learning_rate=STAGE2_LR,\n",
731
+ " weight_decay=WEIGHT_DECAY,\n",
732
+ " warmup_ratio=WARMUP_RATIO,\n",
733
+ " lr_scheduler_type=\"cosine\",\n",
734
+ " eval_strategy=\"epoch\",\n",
735
+ " save_strategy=\"epoch\",\n",
736
+ " save_total_limit=3,\n",
737
+ " load_best_model_at_end=True,\n",
738
+ " metric_for_best_model=\"macro_f1\",\n",
739
+ " greater_is_better=True,\n",
740
+ " bf16=False, # DeBERTa-v3 breaks with fp16 gradient scaler; fp32 is safest on T4\n",
741
+ " fp16=False,\n",
742
+ " logging_strategy=\"steps\",\n",
743
+ " logging_steps=25,\n",
744
+ " logging_first_step=True,\n",
745
+ " disable_tqdm=False,\n",
746
+ " report_to=\"none\",\n",
747
+ " push_to_hub=True,\n",
748
+ " hub_model_id=HUB_MODEL_ID,\n",
749
+ " dataloader_num_workers=2,\n",
750
+ " seed=SEED,\n",
751
+ " gradient_checkpointing=True,\n",
752
+ ")\n",
753
+ "\n",
754
+ "stage2_trainer = ASLTrainer(\n",
755
+ " model=stage2_model,\n",
756
+ " args=stage2_args,\n",
757
+ " asl_loss_fn=asl_loss,\n",
758
+ " train_dataset=cuad_train_tok,\n",
759
+ " eval_dataset=cuad_val_tok,\n",
760
+ " processing_class=tokenizer,\n",
761
+ " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
762
+ " compute_metrics=compute_metrics_cuad_detailed,\n",
763
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],\n",
764
+ ")\n",
765
+ "\n",
766
+ "print(\"\\n🚀 Starting Stage 2 training with Asymmetric Loss...\")\n",
767
+ "stage2_result = stage2_trainer.train()\n",
768
+ "print(f\"\\n✅ Stage 2 complete! Loss: {stage2_result.training_loss:.4f}\")"
769
+ ],
770
+ "metadata": {},
771
+ "execution_count": null,
772
+ "outputs": []
773
+ },
774
+ {
775
+ "cell_type": "markdown",
776
+ "source": [
777
+ "## Step 10: Evaluate Stage 2 on CUAD Test Set"
778
+ ],
779
+ "metadata": {}
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "source": [
784
+ "print(\"📊 Stage 2 — CUAD Test Evaluation\")\n",
785
+ "test_results = stage2_trainer.evaluate(cuad_test_tok)\n",
786
+ "\n",
787
+ "print(f\"\\n{'='*60}\")\n",
788
+ "print(f\" CUAD TEST RESULTS (DeBERTa-v3-large + LEDGAR + ASL)\")\n",
789
+ "print(f\"{'='*60}\")\n",
790
+ "print(f\" Accuracy: {test_results['eval_accuracy']:.4f}\")\n",
791
+ "print(f\" Micro-F1: {test_results['eval_micro_f1']:.4f}\")\n",
792
+ "print(f\" Macro-F1: {test_results['eval_macro_f1']:.4f}\")\n",
793
+ "print(f\" Weighted-F1: {test_results['eval_weighted_f1']:.4f}\")\n",
794
+ "print(f\"{'='*60}\")\n",
795
+ "\n",
796
+ "# Per-class F1 report\n",
797
+ "print(f\"\\n Per-class F1 scores:\")\n",
798
+ "print(f\" {'Class':<42s} {'F1':>6s}\")\n",
799
+ "print(f\" {'-'*48}\")\n",
800
+ "\n",
801
+ "zero_f1_classes = []\n",
802
+ "for i, label_name in enumerate(CUAD_LABELS):\n",
803
+ " safe_name = label_name[:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n",
804
+ " key = f\"eval_f1_{safe_name}\"\n",
805
+ " f1_val = test_results.get(key, 0.0)\n",
806
+ " bar = '█' * int(f1_val * 30)\n",
807
+ " status = \"\" if f1_val > 0 else \" ← ZERO\"\n",
808
+ " print(f\" {i:2d} {label_name:<40s} {f1_val:.4f} {bar}{status}\")\n",
809
+ " if f1_val == 0:\n",
810
+ " zero_f1_classes.append(label_name)\n",
811
+ "\n",
812
+ "print(f\"\\n Classes with zero F1: {len(zero_f1_classes)}\")\n",
813
+ "if zero_f1_classes:\n",
814
+ " for c in zero_f1_classes:\n",
815
+ " print(f\" ⚠️ {c}\")"
816
+ ],
817
+ "metadata": {},
818
+ "execution_count": null,
819
+ "outputs": []
820
+ },
821
+ {
822
+ "cell_type": "markdown",
823
+ "source": [
824
+ "## Step 11: Full Classification Report"
825
+ ],
826
+ "metadata": {}
827
+ },
828
+ {
829
+ "cell_type": "code",
830
+ "source": [
831
+ "# Generate full sklearn classification report\n",
832
+ "from sklearn.metrics import classification_report\n",
833
+ "\n",
834
+ "# Get predictions on test set\n",
835
+ "preds_output = stage2_trainer.predict(cuad_test_tok)\n",
836
+ "preds = np.argmax(preds_output.predictions, axis=-1)\n",
837
+ "labels = preds_output.label_ids\n",
838
+ "\n",
839
+ "# Only include labels that appear in test set\n",
840
+ "present_labels = sorted(set(labels) | set(preds))\n",
841
+ "target_names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f\"Class-{i}\" for i in present_labels]\n",
842
+ "\n",
843
+ "report = classification_report(\n",
844
+ " labels, preds,\n",
845
+ " labels=present_labels,\n",
846
+ " target_names=target_names,\n",
847
+ " zero_division=0,\n",
848
+ " digits=4,\n",
849
+ ")\n",
850
+ "print(\"\\n📊 Full Classification Report:\")\n",
851
+ "print(report)"
852
+ ],
853
+ "metadata": {},
854
+ "execution_count": null,
855
+ "outputs": []
856
+ },
857
+ {
858
+ "cell_type": "markdown",
859
+ "source": [
860
+ "## Step 12: Push Final Model to Hub"
861
+ ],
862
+ "metadata": {}
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "source": [
867
+ "# Save model with proper label mapping\n",
868
+         "stage2_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}  # int keys for in-process lookup\n",
869
+ "stage2_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}\n",
870
+ "\n",
871
+ "# Save locally\n",
872
+ "FINAL_DIR = \"./clauseguard-deberta-final\"\n",
873
+ "stage2_trainer.save_model(FINAL_DIR)\n",
874
+ "tokenizer.save_pretrained(FINAL_DIR)\n",
875
+ "\n",
876
+ "# Push to Hub\n",
877
+ "print(f\"\\n☁️ Pushing model to Hub: {HUB_MODEL_ID}\")\n",
878
+ "stage2_trainer.push_to_hub(\n",
879
+ " commit_message=(\n",
880
+ " f\"ClauseGuard v4: DeBERTa-v3-large 2-stage (LEDGAR→CUAD) with ASL\\n\"\n",
881
+ " f\"CUAD Test: micro-F1={test_results['eval_micro_f1']:.4f}, \"\n",
882
+ " f\"macro-F1={test_results['eval_macro_f1']:.4f}\"\n",
883
+ " )\n",
884
+ ")\n",
885
+ "\n",
886
+ "print(f\"\\n✅ Model pushed to: https://huggingface.co/{HUB_MODEL_ID}\")"
887
+ ],
888
+ "metadata": {},
889
+ "execution_count": null,
890
+ "outputs": []
891
+ },
892
+ {
893
+ "cell_type": "markdown",
894
+ "source": [
895
+ "## Step 13: Test the Model on Sample Clauses"
896
+ ],
897
+ "metadata": {}
898
+ },
899
+ {
900
+ "cell_type": "code",
901
+ "source": [
902
+ "from transformers import pipeline as hf_pipeline\n",
903
+ "\n",
904
+ "# Load the trained model for inference\n",
905
+ "classifier = hf_pipeline(\n",
906
+ " \"text-classification\",\n",
907
+ " model=stage2_model,\n",
908
+ " tokenizer=tokenizer,\n",
909
+ " top_k=5, # return top 5 predictions\n",
910
+ " device=0 if torch.cuda.is_available() else -1,\n",
911
+ ")\n",
912
+ "\n",
913
+ "test_clauses = [\n",
914
+ " # High-risk clauses\n",
915
+ " \"The Company may terminate this Agreement at any time, with or without cause, upon written notice to the other party.\",\n",
916
+ " \"In no event shall the Company be liable for any indirect, incidental, special, or consequential damages arising out of this Agreement.\",\n",
917
+ " \"All intellectual property developed during the term of this Agreement shall be owned exclusively by the Company.\",\n",
918
+ " \"This Agreement shall be governed by and construed in accordance with the laws of the State of Delaware.\",\n",
919
+ " \"Any disputes arising out of this Agreement shall be resolved through binding arbitration in New York.\",\n",
920
+ " \"The Employee agrees not to compete with the Company for a period of two (2) years following termination.\",\n",
921
+ " # Neutral clauses\n",
922
+ " \"This Agreement shall be effective as of January 1, 2024.\",\n",
923
+ " \"The initial term of this Agreement shall be three (3) years.\",\n",
924
+ " \"Either party may assign this Agreement with the prior written consent of the other party.\",\n",
925
+ "]\n",
926
+ "\n",
927
+ "print(\"🧪 Testing model on sample clauses:\\n\")\n",
928
+ "for clause in test_clauses:\n",
929
+ " results = classifier(clause, truncation=True, max_length=MAX_LENGTH)\n",
930
+ " top = results[0] if isinstance(results[0], dict) else results[0][0]\n",
931
+ " top3 = results[:3] if isinstance(results[0], dict) else results[0][:3]\n",
932
+ " \n",
933
+ " print(f\"📄 \\\"{clause[:90]}{'...' if len(clause) > 90 else ''}\\\"\")\n",
934
+ " for r in top3:\n",
935
+ " score = r['score']\n",
936
+ " bar = '█' * int(score * 20)\n",
937
+ " print(f\" → {r['label']:40s} {score:.4f} {bar}\")\n",
938
+ " print()"
939
+ ],
940
+ "metadata": {},
941
+ "execution_count": null,
942
+ "outputs": []
943
+ },
944
+ {
945
+ "cell_type": "markdown",
946
+ "source": [
947
+ "## Step 14: Generate Updated app.py Integration Code\n",
948
+ "\n",
949
+ "Copy-paste this into your ClauseGuard Space's `app.py` to use the new model."
950
+ ],
951
+ "metadata": {}
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "source": [
956
+ "integration_code = f'''\n",
957
+ "# ═══════════════════════════════════════════════════════════════\n",
958
+ "# ClauseGuard v4 — Integration Code\n",
959
+ "# Replace the model loading section in app.py with this:\n",
960
+ "# ═══════════════════════════════════════════════════════════════\n",
961
+ "\n",
962
+ "# OLD (remove these):\n",
963
+ "# base = \"nlpaueb/legal-bert-base-uncased\"\n",
964
+ "# adapter = \"Mokshith31/legalbert-contract-clause-classification\"\n",
965
+ "# from peft import PeftModel\n",
966
+ "\n",
967
+ "# NEW:\n",
968
+ "CLAUSEGUARD_MODEL = \"{HUB_MODEL_ID}\"\n",
969
+ "\n",
970
+ "def _load_cuad_model():\n",
971
+ " global cuad_tokenizer, cuad_model, _model_status\n",
972
+ " if not _HAS_TORCH:\n",
973
+ " _model_status[\"cuad\"] = \"unavailable\"\n",
974
+ " return\n",
975
+ " try:\n",
976
+ " print(f\"[ClauseGuard] Loading classifier: {{CLAUSEGUARD_MODEL}}\")\n",
977
+ " cuad_tokenizer = AutoTokenizer.from_pretrained(CLAUSEGUARD_MODEL)\n",
978
+ " cuad_model = AutoModelForSequenceClassification.from_pretrained(CLAUSEGUARD_MODEL)\n",
979
+ " cuad_model.eval()\n",
980
+ " _model_status[\"cuad\"] = \"loaded\"\n",
981
+ " print(f\"[ClauseGuard] Model loaded: {{sum(p.numel() for p in cuad_model.parameters()):,}} params\")\n",
982
+ " except Exception as e:\n",
983
+ " print(f\"[ClauseGuard] Model load failed: {{e}}\")\n",
984
+ " _model_status[\"cuad\"] = f\"failed: {{e}}\"\n",
985
+ "\n",
986
+ "# In classify_cuad(), change max_length:\n",
987
+ "# max_length=256 → max_length=512\n",
988
+ "#\n",
989
+ "# Also: since the new model is single-label (softmax),\n",
990
+ "# change the prediction logic from sigmoid to:\n",
991
+ "#\n",
992
+ "# probs = torch.softmax(logits, dim=-1)[0] # instead of sigmoid\n",
993
+ "# top_indices = torch.argsort(probs, descending=True)[:5]\n",
994
+ "# for i in top_indices:\n",
995
+ "# if float(probs[i]) > 0.10: # confidence threshold\n",
996
+ "# label = CUAD_LABELS[i]\n",
997
+ "# ...\n",
998
+ "\n",
999
+ "# No more PEFT dependency needed!\n",
1000
+ "# No more ignore_mismatched_sizes!\n",
1001
+ "# Just load directly — the model already has the correct head.\n",
1002
+ "'''\n",
1003
+ "\n",
1004
+ "print(integration_code)"
1005
+ ],
1006
+ "metadata": {},
1007
+ "execution_count": null,
1008
+ "outputs": []
1009
+ },
1010
+ {
1011
+ "cell_type": "markdown",
1012
+ "source": [
1013
+ "## Step 15: Comparison with Current Model\n",
1014
+ "\n",
1015
+ "| Metric | Current (Legal-BERT + LoRA) | New (DeBERTa-v3-large + ASL) |\n",
1016
+ "|--------|---------------------------|-----------------------------|\n",
1017
+ "| Base model | 110M params | 435M params |\n",
1018
+ "| Training | LoRA (frozen backbone) | Full fine-tune |\n",
1019
+ "| Pre-training | None | LEDGAR (60K, 100 classes) |\n",
1020
+ "| Max tokens | 256 | 512 |\n",
1021
+ "| Loss function | Cross-entropy | Asymmetric Loss |\n",
1022
+ "| Zero-F1 classes | 10 of 41 | TBD (should be much fewer) |\n",
1023
+ "| Macro-F1 | ~50% | Target: 78-87% |\n",
1024
+ "\n",
1025
+ "---\n",
1026
+ "\n",
1027
+ "## ✅ Done!\n",
1028
+ "\n",
1029
+ "Your trained model is at: **https://huggingface.co/gaurv007/clauseguard-deberta-v3-large**\n",
1030
+ "\n",
1031
+ "### Next Steps:\n",
1032
+ "1. Update ClauseGuard Space to use this model (see integration code above)\n",
1033
+ "2. Remove PEFT dependency from requirements.txt\n",
1034
+ "3. Consider training SetFit classifiers for any remaining zero-F1 classes\n",
1035
+ "4. Add OCR support (Feature #2)\n",
1036
+ "5. Add RAG chatbot (Feature #3)"
1037
+ ],
1038
+ "metadata": {}
1039
+ }
1040
+ ]
1041
+ }
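The asymmetric (focal-style) loss from Step 7 can be sanity-checked outside the notebook. The sketch below reimplements just the multi-class path in a standalone class (the name `AsymmetricLossMC` and the demo logits are ours, not from the notebook) and confirms that a confidently-correct example contributes far less loss than an uncertain one:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AsymmetricLossMC(nn.Module):
    """Standalone copy of the notebook's multi-class variant:
    cross-entropy modulated by a focal weight (1 - p_t)^gamma_neg."""

    def __init__(self, gamma_neg=4):
        super().__init__()
        self.gamma_neg = gamma_neg

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        # probability assigned to the correct class
        p_t = F.softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
        return (((1 - p_t) ** self.gamma_neg) * ce).mean()


loss_fn = AsymmetricLossMC(gamma_neg=4)
target = torch.tensor([0])
easy = torch.tensor([[6.0, 0.0, 0.0]])   # confident, correct prediction
hard = torch.tensor([[0.5, 0.0, 0.0]])   # uncertain prediction
l_easy = loss_fn(easy, target).item()
l_hard = loss_fn(hard, target).item()
print(l_easy < l_hard)  # easy example is down-weighted to near zero
```

Down-weighting easy examples this aggressively is what lets the rare CUAD classes dominate the gradient once the frequent classes are learned.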
ml/requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- transformers==5.6.1
2
  datasets>=3.2.0
3
  torch>=2.5.0
4
  scikit-learn>=1.6.0
5
  accelerate>=1.2.0
6
- optimum[onnxruntime]>=1.24.0
 
1
+ transformers>=5.6.0
2
  datasets>=3.2.0
3
  torch>=2.5.0
4
  scikit-learn>=1.6.0
5
  accelerate>=1.2.0
6
+ huggingface_hub>=0.27.0
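Step 14 of the notebook notes that the new model is single-label, so inference switches from sigmoid to softmax with a top-k confidence cutoff. A minimal standalone sketch of that prediction logic (the label list and logit values below are invented for illustration):

```python
import torch

# Invented demo labels and logits; the real model has 41 CUAD classes
CUAD_LABELS_DEMO = ["Governing Law", "Non-Compete", "Cap on Liability", "Other"]
logits = torch.tensor([2.0, 0.1, -1.0, 0.5])

probs = torch.softmax(logits, dim=-1)              # single-label head: softmax, not sigmoid
top_indices = torch.argsort(probs, descending=True)[:3]
preds = [
    (CUAD_LABELS_DEMO[int(i)], float(probs[i]))
    for i in top_indices
    if float(probs[i]) > 0.10                      # confidence threshold from Step 14
]
for label, p in preds:
    print(f"{label}: {p:.3f}")
```

With softmax the scores sum to 1 across classes, so the 0.10 threshold prunes the long tail of the distribution rather than acting per-class as a sigmoid threshold would.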
ml/train_classifier_v4.py ADDED
@@ -0,0 +1,434 @@
1
+ """
2
+ ClauseGuard v4 — 2-Stage DeBERTa-v3-large Training Script
3
+ ═══════════════════════════════════════════════════════════
4
+
5
+ Stage 1: Pre-fine-tune on LEDGAR (60K legal provisions, 100 classes)
6
+ Stage 2: Fine-tune on CUAD (41 classes) with Asymmetric Loss
7
+
8
+ Usage:
9
+ python train_classifier_v4.py # Full 2-stage pipeline
10
+ python train_classifier_v4.py --stage 1 # Stage 1 only
11
+ python train_classifier_v4.py --stage 2 --checkpoint ./stage1_ledgar_best # Stage 2 only
12
+
13
+ Requirements:
14
+ pip install transformers datasets scikit-learn accelerate torch
15
+
16
+ Hardware: A100 80GB recommended (~4-6 hours total)
17
+ """
18
+
19
+ import os
20
+ import gc
21
+ import argparse
22
+ import json
23
+ from collections import Counter
24
+ from datetime import datetime
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from datasets import load_dataset, Dataset
31
+ from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
32
+ from sklearn.model_selection import train_test_split
33
+ from transformers import (
34
+ AutoConfig,
35
+ AutoModelForSequenceClassification,
36
+ AutoTokenizer,
37
+ DataCollatorWithPadding,
38
+ Trainer,
39
+ TrainingArguments,
40
+ EarlyStoppingCallback,
41
+ )
42
+
43
+
44
+ # ═══════════════════════════════════════════════════════════════
45
+ # CONFIGURATION
46
+ # ═══════════════════════════════════════════════════════════════
47
+
48
+ BASE_MODEL = os.environ.get("BASE_MODEL", "microsoft/deberta-v3-large")
49
+ MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
50
+ HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-deberta-v3-large")
51
+ PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
52
+ SEED = 42
53
+
54
+ CUAD_LABELS = [
55
+ "Document Name", "Parties", "Agreement Date", "Effective Date",
56
+ "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
57
+ "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
58
+ "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
59
+ "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
60
+ "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
61
+ "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
62
+ "Joint IP Ownership", "License Grant", "Non-Transferable License",
63
+ "Affiliate License-Licensor", "Affiliate License-Licensee",
64
+ "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
65
+ "Source Code Escrow", "Post-Termination Services", "Audit Rights",
66
+ "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
67
+ "Warranty Duration", "Insurance", "Covenant Not to Sue",
68
+ "Third Party Beneficiary", "Other",
69
+ ]
70
+ NUM_CUAD_LABELS = len(CUAD_LABELS)
71
+
72
+
73
+ # ═══════════════════════════════════════════════════════════════
74
+ # ASYMMETRIC LOSS (arxiv:2009.14119)
75
+ # ═══════════════════════════════════════════════════════════════
76
+
77
+ class AsymmetricLoss(nn.Module):
78
+ """Focal-style loss with asymmetric gamma for class imbalance."""
79
+
80
+ def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,
81
+ class_weights=None):
82
+ super().__init__()
83
+ self.gamma_pos = gamma_pos
84
+ self.gamma_neg = gamma_neg
85
+ self.clip = clip
86
+ self.eps = eps
87
+ if class_weights is not None:
88
+ self.register_buffer('class_weights',
89
+ torch.tensor(class_weights, dtype=torch.float32))
90
+ else:
91
+ self.class_weights = None
92
+
93
+ def forward(self, logits, targets):
94
+ """Multi-class focal cross-entropy with class weights."""
95
+ if self.class_weights is not None:
96
+ ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights,
97
+ reduction='none')
98
+ else:
99
+ ce_loss = F.cross_entropy(logits, targets, reduction='none')
100
+
101
+ probs = F.softmax(logits, dim=-1)
102
+ p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
103
+ focal_weight = (1 - p_t) ** self.gamma_neg
104
+ loss = focal_weight * ce_loss
105
+ return loss.mean()
106
+
107
+
108
+ # ═══════════════════════════════════════════════════════════════
109
+ # CUSTOM TRAINER
110
+ # ═══════════════════════════════════════════════════════════════
111
+
112
+ class ASLTrainer(Trainer):
113
+ def __init__(self, *args, asl_loss_fn=None, **kwargs):
114
+ super().__init__(*args, **kwargs)
115
+ self.asl = asl_loss_fn
116
+
117
+ def compute_loss(self, model, inputs, return_outputs=False,
118
+ num_items_in_batch=None):
119
+ labels = inputs.pop("labels")
120
+ outputs = model(**inputs)
121
+ logits = outputs.logits
122
+ if self.asl is not None:
123
+ loss = self.asl(logits, labels)
124
+ else:
125
+ loss = F.cross_entropy(logits, labels)
126
+ return (loss, outputs) if return_outputs else loss
127
+
128
+
+ # ═══════════════════════════════════════════════════════════════
+ # METRICS
+ # ═══════════════════════════════════════════════════════════════
+
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred.predictions, eval_pred.label_ids
+     preds = np.argmax(logits, axis=-1)
+     return {
+         "accuracy": (preds == labels).mean(),
+         "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
+         "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
+         "weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
+     }
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # STAGE 1: LEDGAR
+ # ═══════════════════════════════════════════════════════════════
+
+ def run_stage1(tokenizer, output_dir="./stage1_ledgar_best"):
+     print("\n" + "=" * 60)
+     print(" STAGE 1: Pre-fine-tune on LEDGAR (100 classes)")
+     print("=" * 60)
+
+     ledgar = load_dataset("coastalcph/lex_glue", "ledgar")
+     num_labels = ledgar['train'].features['label'].num_classes
+     print(f" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,}")
+     print(f" Classes: {num_labels}")
+
+     def preprocess(examples):
+         tok = tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH,
+                         padding=False)
+         tok["labels"] = examples["label"]
+         return tok
+
+     tokenized = ledgar.map(preprocess, batched=True,
+                            remove_columns=ledgar["train"].column_names)
+
+     model = AutoModelForSequenceClassification.from_pretrained(
+         BASE_MODEL, num_labels=num_labels,
+         problem_type="single_label_classification",
+         ignore_mismatched_sizes=True,
+     )
+     print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     args = TrainingArguments(
+         output_dir="./stage1_ledgar",
+         num_train_epochs=5,
+         per_device_train_batch_size=8,
+         per_device_eval_batch_size=16,
+         gradient_accumulation_steps=4,
+         learning_rate=2e-5,
+         weight_decay=0.06,
+         warmup_ratio=0.1,
+         lr_scheduler_type="cosine",
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         metric_for_best_model="macro_f1",
+         greater_is_better=True,
+         bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
+         fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+         logging_strategy="steps",
+         logging_steps=50,
+         logging_first_step=True,
+         disable_tqdm=True,
+         report_to="none",
+         dataloader_num_workers=2,
+         seed=SEED,
+         gradient_checkpointing=True,
+     )
+
+     trainer = Trainer(
+         model=model, args=args,
+         train_dataset=tokenized["train"],
+         eval_dataset=tokenized["validation"],
+         processing_class=tokenizer,
+         data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
+         compute_metrics=compute_metrics,
+         callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
+     )
+
+     result = trainer.train()
+     print(f"\n Stage 1 training loss: {result.training_loss:.4f}")
+
+     test_metrics = trainer.evaluate(tokenized["test"])
+     print(f" Stage 1 test micro-F1: {test_metrics['eval_micro_f1']:.4f}")
+     print(f" Stage 1 test macro-F1: {test_metrics['eval_macro_f1']:.4f}")
+
+     trainer.save_model(output_dir)
+     tokenizer.save_pretrained(output_dir)
+     print(f" Saved to {output_dir}")
+
+     del model, trainer
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     return output_dir
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # STAGE 2: CUAD
+ # ═══════════════════════════════════════════════════════════════
+
+ def run_stage2(tokenizer, checkpoint_path, output_dir="./clauseguard-deberta-final"):
+     print("\n" + "=" * 60)
+     print(f" STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL")
+     print("=" * 60)
+
+     # Load CUAD and split by contract file so no document leaks across splits
+     cuad_raw = load_dataset(
+         "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
+         split="train"
+     )
+     cuad_df = cuad_raw.to_pandas()
+
+     unique_files = cuad_df['file_name'].unique()
+     train_files, test_files = train_test_split(unique_files, test_size=0.2,
+                                                random_state=SEED)
+     val_files, test_files = train_test_split(test_files, test_size=0.5,
+                                              random_state=SEED)
+
+     splits = {
+         "train": Dataset.from_pandas(
+             cuad_df[cuad_df['file_name'].isin(train_files)].reset_index(drop=True)
+         ),
+         "val": Dataset.from_pandas(
+             cuad_df[cuad_df['file_name'].isin(val_files)].reset_index(drop=True)
+         ),
+         "test": Dataset.from_pandas(
+             cuad_df[cuad_df['file_name'].isin(test_files)].reset_index(drop=True)
+         ),
+     }
+
+     for name, ds in splits.items():
+         print(f" {name}: {len(ds)} rows")
+
+     def preprocess_cuad(examples):
+         tok = tokenizer(examples["clause"], truncation=True, max_length=MAX_LENGTH,
+                         padding=False)
+         tok["labels"] = examples["class_id"]
+         return tok
+
+     tok_splits = {}
+     for name, ds in splits.items():
+         tok_splits[name] = ds.map(preprocess_cuad, batched=True,
+                                   remove_columns=ds.column_names)
+
+     # Load model from Stage 1 checkpoint
+     model = AutoModelForSequenceClassification.from_pretrained(
+         checkpoint_path,
+         num_labels=NUM_CUAD_LABELS,
+         ignore_mismatched_sizes=True,
+         problem_type="single_label_classification",
+     )
+
+     # Update label mapping
+     model.config.id2label = {str(i): name for i, name in enumerate(CUAD_LABELS)}
+     model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}
+
+     print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     # Compute inverse-frequency class weights, capped at 10x
+     train_counts = Counter(tok_splits["train"]["labels"])
+     total = sum(train_counts.values())
+     class_weights = []
+     for i in range(NUM_CUAD_LABELS):
+         count = train_counts.get(i, 1)
+         weight = min(10.0, total / (NUM_CUAD_LABELS * count))
+         class_weights.append(weight)
+
+     asl = AsymmetricLoss(gamma_pos=0, gamma_neg=4, clip=0.05,
+                          class_weights=class_weights)
+     if torch.cuda.is_available():
+         asl = asl.cuda()
+
+     args = TrainingArguments(
+         output_dir="./stage2_cuad",
+         num_train_epochs=20,
+         per_device_train_batch_size=8,
+         per_device_eval_batch_size=16,
+         gradient_accumulation_steps=4,
+         learning_rate=1e-5,
+         weight_decay=0.06,
+         warmup_ratio=0.1,
+         lr_scheduler_type="cosine",
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         save_total_limit=3,
+         load_best_model_at_end=True,
+         metric_for_best_model="macro_f1",
+         greater_is_better=True,
+         bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
+         fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+         logging_strategy="steps",
+         logging_steps=25,
+         logging_first_step=True,
+         disable_tqdm=True,
+         report_to="none",
+         push_to_hub=PUSH_TO_HUB,
+         hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
+         dataloader_num_workers=2,
+         seed=SEED,
+         gradient_checkpointing=True,
+     )
+
+     trainer = ASLTrainer(
+         model=model, args=args,
+         asl_loss_fn=asl,
+         train_dataset=tok_splits["train"],
+         eval_dataset=tok_splits["val"],
+         processing_class=tokenizer,
+         data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
+         compute_metrics=compute_metrics,
+         callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
+     )
+
+     result = trainer.train()
+     print(f"\n Stage 2 training loss: {result.training_loss:.4f}")
+
+     # Evaluate
+     test_metrics = trainer.evaluate(tok_splits["test"])
+     print(f"\n{'='*60}")
+     print(" CUAD TEST RESULTS")
+     print(f"{'='*60}")
+     print(f" Accuracy:    {test_metrics['eval_accuracy']:.4f}")
+     print(f" Micro-F1:    {test_metrics['eval_micro_f1']:.4f}")
+     print(f" Macro-F1:    {test_metrics['eval_macro_f1']:.4f}")
+     print(f" Weighted-F1: {test_metrics['eval_weighted_f1']:.4f}")
+
+     # Full per-class report
+     preds_out = trainer.predict(tok_splits["test"])
+     preds = np.argmax(preds_out.predictions, axis=-1)
+     labels = preds_out.label_ids
+     present = sorted(set(labels) | set(preds))
+     names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f"Class-{i}" for i in present]
+     print("\n" + classification_report(labels, preds, labels=present,
+                                        target_names=names, zero_division=0, digits=4))
+
+     # Save
+     trainer.save_model(output_dir)
+     tokenizer.save_pretrained(output_dir)
+
+     if PUSH_TO_HUB:
+         trainer.push_to_hub(
+             commit_message=(
+                 f"ClauseGuard v4: DeBERTa-v3-large LEDGAR→CUAD + ASL | "
+                 f"micro-F1={test_metrics['eval_micro_f1']:.4f} "
+                 f"macro-F1={test_metrics['eval_macro_f1']:.4f}"
+             )
+         )
+         print(f"\n Pushed to https://huggingface.co/{HUB_MODEL_ID}")
+
+     # Save test results
+     results_path = os.path.join(output_dir, "test_results.json")
+     with open(results_path, "w") as f:
+         json.dump({
+             "model": HUB_MODEL_ID,
+             "base_model": BASE_MODEL,
+             "max_length": MAX_LENGTH,
+             "stage1_dataset": "coastalcph/lex_glue (ledgar)",
+             "stage2_dataset": "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
+             "test_results": {k: float(v) for k, v in test_metrics.items()
+                              if isinstance(v, (int, float))},
+             "timestamp": datetime.now().isoformat(),
+         }, f, indent=2)
+
+     return output_dir
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # MAIN
+ # ═══════════════════════════════════════════════════════════════
+
+ def main():
+     parser = argparse.ArgumentParser(description="ClauseGuard v4 Training")
+     parser.add_argument("--stage", type=int, default=0,
+                         help="Run specific stage (1 or 2). Default: both")
+     parser.add_argument("--checkpoint", type=str, default="./stage1_ledgar_best",
+                         help="Stage 1 checkpoint path for Stage 2")
+     args = parser.parse_args()
+
+     print("🛡️ ClauseGuard v4 Training")
+     print(f" Model: {BASE_MODEL}")
+     print(f" Max length: {MAX_LENGTH}")
+     print(f" Hub: {HUB_MODEL_ID}")
+     if torch.cuda.is_available():
+         print(f" GPU: {torch.cuda.get_device_name(0)}")
+         # Fix: the attribute is total_memory, not total_mem
+         print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+
+     if args.stage in (0, 1):
+         checkpoint = run_stage1(tokenizer)
+     else:
+         checkpoint = args.checkpoint
+
+     if args.stage in (0, 2):
+         run_stage2(tokenizer, checkpoint)
+
+     print("\n✅ Training complete!")
+
+
+ if __name__ == "__main__":
+     main()
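
Stage 2's capped inverse-frequency weighting can be sketched in isolation (a minimal illustration, not part of the committed script; `compute_class_weights` is a hypothetical helper name):

```python
from collections import Counter

def compute_class_weights(labels, num_classes, cap=10.0):
    """Inverse-frequency weights, capped so ultra-rare classes
    cannot dominate the loss (mirrors the Stage 2 logic)."""
    counts = Counter(labels)
    total = sum(counts.values())
    # Absent classes default to count=1, as in the training script
    return [min(cap, total / (num_classes * counts.get(i, 1)))
            for i in range(num_classes)]

weights = compute_class_weights([0, 0, 0, 1], num_classes=3)
print(weights)  # class 0 (common) gets the smallest weight
```

The cap matters on CUAD-style long tails: without it, a class seen once in tens of thousands of rows would get a weight in the hundreds and destabilize training.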
obligations.py CHANGED
@@ -120,18 +120,22 @@ def extract_obligations(text):
         if not found_types:
             continue
 
-        # Extract party
+        # Extract party (Fix 8: scope to sentence only, reject >40 char strings)
         party = "Unknown"
-        for pp in PARTY_PATTERNS:
-            m = re.search(pp, sentence)
-            if m:
-                party = m.group(0).strip()
-                break
-
-        # Try to determine which party has the obligation based on sentence structure
+        # First try structured direction detection
        obligation_direction = _detect_obligation_direction(sentence)
         if obligation_direction:
             party = obligation_direction
+        else:
+            # Fallback to pattern matching within the sentence
+            for pp in PARTY_PATTERNS:
+                m = re.search(pp, sentence)
+                if m:
+                    candidate = m.group(0).strip()
+                    # Fix 8: Reject party strings >40 chars (header bleed-through)
+                    if len(candidate) <= 40:
+                        party = candidate
+                        break
 
         # Extract timeframe
         deadline = "Not specified"
ocr_engine.py ADDED
@@ -0,0 +1,218 @@
+ """
2
+ ClauseGuard — OCR Engine v1.0
3
+ ═════════════════════════════
4
+ Smart PDF Router: detects native vs scanned PDFs.
5
+ • Native PDF → pdfplumber (fast, existing)
6
+ • Scanned PDF → docTR OCR (CPU-friendly, ~150MB models)
7
+
8
+ Architecture:
9
+ PDF uploaded
10
+
11
+ [detect_if_scanned] — pdfplumber gets <50 chars/page?
12
+ ↓ ↓
13
+ Native PDF Scanned PDF
14
+ ↓ ↓
15
+ pdfplumber docTR OCR (CPU)
16
+ ↓ ↓
17
+ Contract text → existing analysis pipeline
18
+ """
19
+
20
+ import os
21
+ import re
22
+
23
+ # ── docTR (soft-fail) ───────────────────────────────────────────────
24
+ _HAS_DOCTR = False
25
+ _ocr_predictor = None
26
+
27
+ try:
28
+ from doctr.io import DocumentFile
29
+ from doctr.models import ocr_predictor as _make_predictor
30
+ _HAS_DOCTR = True
31
+ except ImportError:
32
+ pass
33
+
34
+ # ── pdfplumber (soft-fail) ──────────────────────────────────────────
35
+ try:
36
+ import pdfplumber
37
+ _HAS_PDF = True
38
+ except ImportError:
39
+ _HAS_PDF = False
40
+
41
+ # ═══════════════════════════════════════════════════════════════════════
42
+ # OCR MODEL LOADING
43
+ # ═══════════════════════════════════════════════════════════════════════
44
+
45
+ _ocr_status = "not_loaded"
46
+
47
+ def _load_ocr_model():
48
+ """Load docTR OCR predictor (lazy, on first use)."""
49
+ global _ocr_predictor, _ocr_status
50
+ if _ocr_predictor is not None:
51
+ return _ocr_predictor
52
+ if not _HAS_DOCTR:
53
+ _ocr_status = "unavailable (python-doctr not installed)"
54
+ return None
55
+ try:
56
+ print("[ClauseGuard OCR] Loading docTR models (fast_base + crnn_vgg16_bn)...")
57
+ _ocr_predictor = _make_predictor(
58
+ det_arch="fast_base",
59
+ reco_arch="crnn_vgg16_bn",
60
+ pretrained=True,
61
+ assume_straight_pages=True,
62
+ )
63
+ _ocr_status = "loaded"
64
+ print("[ClauseGuard OCR] docTR models loaded successfully")
65
+ return _ocr_predictor
66
+ except Exception as e:
67
+ _ocr_status = f"failed: {e}"
68
+ print(f"[ClauseGuard OCR] docTR load failed: {e}")
69
+ return None
70
+
71
+
72
+ def get_ocr_status():
73
+ """Return human-readable OCR engine status."""
74
+ if _ocr_predictor is not None:
75
+ return "✅ OCR: docTR loaded"
76
+ elif _HAS_DOCTR:
77
+ return "⏳ OCR: docTR available (not yet loaded)"
78
+ else:
79
+ return "❌ OCR: unavailable (python-doctr not installed)"
80
+
81
+
82
+ # ═══════════════════════════════════════════════════════════════════════
83
+ # SMART PDF ROUTER
84
+ # ═══════════════════════════════════════════════════════════════════════
85
+
86
+ def _is_scanned_pdf(file_path, min_chars_per_page=50):
87
+ """
88
+ Detect if a PDF is scanned (image-based) by checking if pdfplumber
89
+ extracts fewer than `min_chars_per_page` characters on average.
90
+ """
91
+ if not _HAS_PDF:
92
+ return True # Can't check with pdfplumber, assume scanned
93
+ try:
94
+ with pdfplumber.open(file_path) as pdf:
95
+ if len(pdf.pages) == 0:
96
+ return True
97
+ total_chars = 0
98
+ pages_checked = min(len(pdf.pages), 5) # Check first 5 pages
99
+ for i in range(pages_checked):
100
+ page_text = pdf.pages[i].extract_text() or ""
101
+ total_chars += len(page_text.strip())
102
+ avg_chars = total_chars / pages_checked
103
+ return avg_chars < min_chars_per_page
104
+ except Exception:
105
+ return True # If pdfplumber fails, try OCR
106
+
107
+
108
+ def _extract_native_pdf(file_path):
109
+ """Extract text from a native (digital) PDF using pdfplumber."""
110
+ if not _HAS_PDF:
111
+ return None, "pdfplumber not installed"
112
+ try:
113
+ text = ""
114
+ with pdfplumber.open(file_path) as pdf:
115
+ for page in pdf.pages:
116
+ page_text = page.extract_text()
117
+ if page_text:
118
+ text += page_text + "\n\n"
119
+ if not text.strip():
120
+ return None, "No text extracted from PDF"
121
+ return text.strip(), None
122
+ except Exception as e:
123
+ return None, f"PDF parse error: {e}"
124
+
125
+
126
+ def _extract_scanned_pdf(file_path):
127
+ """Extract text from a scanned PDF using docTR OCR."""
128
+ predictor = _load_ocr_model()
129
+ if predictor is None:
130
+ return None, (
131
+ "OCR is not available. Install python-doctr: "
132
+ "`pip install python-doctr[torch]`"
133
+ )
134
+ try:
135
+ doc = DocumentFile.from_pdf(file_path)
136
+ result = predictor(doc)
137
+
138
+ # Extract text page by page
139
+ full_text = ""
140
+ for page_idx, page in enumerate(result.pages):
141
+ page_text = ""
142
+ for block in page.blocks:
143
+ for line in block.lines:
144
+ line_text = " ".join(word.value for word in line.words)
145
+ page_text += line_text + "\n"
146
+ page_text += "\n"
147
+ full_text += page_text + "\n\n"
148
+
149
+ if not full_text.strip():
150
+ return None, "OCR could not extract text from scanned PDF"
151
+
152
+ # Clean up OCR artifacts
153
+ full_text = _clean_ocr_text(full_text)
154
+ return full_text.strip(), None
155
+ except Exception as e:
156
+ return None, f"OCR error: {e}"
157
+
158
+
159
+ def _clean_ocr_text(text):
160
+ """Clean common OCR artifacts."""
161
+ # Remove excessive whitespace
162
+ text = re.sub(r'[ \t]{3,}', ' ', text)
163
+ # Fix common OCR substitutions
164
+ text = re.sub(r'\bl\b(?=[A-Z])', 'I', text) # l before capital → I
165
+ # Normalize line breaks
166
+ text = re.sub(r'\n{4,}', '\n\n\n', text)
167
+ # Remove single-char lines (OCR noise)
168
+ lines = text.split('\n')
169
+ cleaned_lines = []
170
+ for line in lines:
171
+ stripped = line.strip()
172
+ if len(stripped) <= 1 and stripped not in ('', '.', ',', ';'):
173
+ continue
174
+ cleaned_lines.append(line)
175
+ return '\n'.join(cleaned_lines)
176
+
177
+
178
+ # ═══════════════════════════════════════════════════════════════════════
179
+ # PUBLIC API
180
+ # ═══════════════════════════════════════════════════════════════════════
181
+
182
+ def parse_pdf_smart(file_path):
183
+ """
184
+ Smart PDF parser with OCR fallback.
185
+
186
+ Returns: (text, error, method)
187
+ text: extracted text (or None)
188
+ error: error message (or None)
189
+ method: "native" | "ocr" | None
190
+ """
191
+ if not os.path.exists(file_path):
192
+ return None, "File not found", None
193
+
194
+ # Step 1: Check if PDF is scanned
195
+ is_scanned = _is_scanned_pdf(file_path)
196
+
197
+ if not is_scanned:
198
+ # Step 2a: Native PDF — use pdfplumber
199
+ text, error = _extract_native_pdf(file_path)
200
+ if text:
201
+ return text, None, "native"
202
+ # If pdfplumber returns empty, fall through to OCR
203
+ print("[ClauseGuard OCR] pdfplumber returned empty — falling back to OCR")
204
+
205
+ # Step 2b: Scanned PDF or pdfplumber failed — use OCR
206
+ print(f"[ClauseGuard OCR] {'Scanned' if is_scanned else 'Empty native'} PDF detected — running docTR OCR...")
207
+ text, error = _extract_scanned_pdf(file_path)
208
+ if text:
209
+ return text, None, "ocr"
210
+ return None, error, None
211
+
212
+
213
+ def ocr_extract(file_path):
214
+ """
215
+ Force OCR extraction on a PDF (bypass native text check).
216
+ Useful when user explicitly wants OCR.
217
+ """
218
+ return _extract_scanned_pdf(file_path)
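
The routing decision itself is a simple threshold on average extracted characters. A minimal, I/O-free sketch of that core (the real `_is_scanned_pdf` wraps this in pdfplumber calls; `is_scanned` here is a hypothetical name):

```python
def is_scanned(page_char_counts, min_chars_per_page=50, sample=5):
    """Average the extracted-character counts of the first few pages;
    below the threshold, treat the PDF as scanned and route to OCR."""
    if not page_char_counts:
        return True  # empty PDF: let OCR have a try
    checked = page_char_counts[:sample]
    return sum(checked) / len(checked) < min_chars_per_page

print(is_scanned([0, 3, 10]))    # scanned  -> route to docTR OCR
print(is_scanned([1400, 1250]))  # native   -> route to pdfplumber
```

Sampling only the first few pages keeps detection fast on long contracts while still catching mixed documents that open with a scanned cover page.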
redlining.py ADDED
@@ -0,0 +1,591 @@
+ """
2
+ ClauseGuard — Clause Redlining Engine v1.0
3
+ ═══════════════════════════════════════════
4
+ 3-Tier Hybrid Architecture:
5
+ Tier 1 — Template lookup (instant, zero hallucination risk)
6
+ Tier 2 — RAG retrieval from clause corpus (find fairer precedents)
7
+ Tier 3 — LLM refinement (adapt template using retrieved precedents)
8
+
9
+ Anti-hallucination guardrails:
10
+ • Template anchor: LLM can only refine, not generate from scratch
11
+ • RAG grounding: Retrieved precedents constrain the output space
12
+ • Disclaimer: "Not legal advice. Consult an attorney before executing."
13
+ • Legal citation: Prompt requires LLM to cite the consumer protection standard applied
14
+ """
15
+
16
+ import os
17
+ import re
18
+ from collections import defaultdict
19
+
20
+ # ── HF Inference Client (soft-fail) ─────────────────────────────────
21
+ _HAS_INFERENCE = False
22
+ try:
23
+ from huggingface_hub import InferenceClient
24
+ _HAS_INFERENCE = True
25
+ except ImportError:
26
+ pass
27
+
28
+ # ═══════════════════════════════════════════════════════════════════════
29
+ # TIER 1: TEMPLATE LIBRARY (18+ clause types)
30
+ # ═══════════════════════════════════════════════════════════════════════
31
+ # Based on FTC guidelines, EU Directive 93/13, and CFPB guidance.
32
+
33
+ SAFE_ALTERNATIVES = {
34
+ # ── CRITICAL Risk Clauses ──────────────────────────────────────
35
+ "Uncapped Liability": {
36
+ "risky_pattern": "Total liability shall not exceed $1 / unlimited liability exposure",
37
+ "safe_alternative": (
38
+ "Provider's aggregate liability under this Agreement shall not exceed the total "
39
+ "fees paid by the Customer in the twelve (12) months preceding the claim. "
40
+ "This limitation shall not apply to: (a) gross negligence or willful misconduct, "
41
+ "(b) breach of confidentiality obligations, (c) intellectual property indemnification "
42
+ "obligations, or (d) violations of applicable law."
43
+ ),
44
+ "legal_basis": "UCC § 2-719; Restatement (Second) of Contracts § 356",
45
+ "consumer_standard": "FTC guidelines on unconscionable contract terms",
46
+ "risk_level": "CRITICAL",
47
+ },
48
+ "Arbitration": {
49
+ "risky_pattern": "All disputes via binding arbitration / class action waiver",
50
+ "safe_alternative": (
51
+ "Disputes involving claims under [Dollar Amount] shall be resolved in small claims "
52
+ "court in the consumer's jurisdiction of residence. For other disputes, either party "
53
+ "may elect binding arbitration under [AAA/JAMS] rules. The consumer may opt out of "
54
+ "arbitration by providing written notice within thirty (30) days of accepting these "
55
+ "terms. Each party bears its own arbitration costs; the prevailing party may recover "
56
+ "reasonable attorney's fees."
57
+ ),
58
+ "legal_basis": "Federal Arbitration Act § 2; AT&T Mobility v. Concepcion, 563 U.S. 333 (2011)",
59
+ "consumer_standard": "CFPB Arbitration Rule guidance; EU Directive 93/13/EEC Art. 3",
60
+ "risk_level": "CRITICAL",
61
+ },
62
+ "IP Ownership Assignment": {
63
+ "risky_pattern": "All IP rights assigned to company / work-for-hire everything",
64
+ "safe_alternative": (
65
+ "Intellectual property created by the Receiving Party specifically in performance of "
66
+ "this Agreement ('Work Product IP') shall be assigned to the Disclosing Party. "
67
+ "Pre-existing IP and general knowledge, skills, and experience of the Receiving Party "
68
+ "remain the Receiving Party's property. The Disclosing Party grants the Receiving Party "
69
+ "a non-exclusive, perpetual license to use Work Product IP for internal portfolio and "
70
+ "reference purposes."
71
+ ),
72
+ "legal_basis": "17 U.S.C. § 101 (work for hire); Copyright Act § 201(b)",
73
+ "consumer_standard": "Standard IP assignment with carve-outs for pre-existing IP",
74
+ "risk_level": "CRITICAL",
75
+ },
76
+ "Termination for Convenience": {
77
+ "risky_pattern": "Terminate at any time without notice",
78
+ "safe_alternative": (
79
+ "Either party may terminate this Agreement for convenience upon thirty (30) days' "
80
+ "prior written notice. Immediate termination is permitted only for material breach "
81
+ "that remains uncured after a ten (10) day cure period following written notice "
82
+ "specifying the breach. Upon termination: (a) all outstanding fees become due, "
83
+ "(b) each party shall return or destroy confidential information within fifteen (15) "
84
+ "business days, and (c) licenses granted hereunder shall terminate except as "
85
+ "expressly stated to survive."
86
+ ),
87
+ "legal_basis": "Restatement (Second) of Contracts § 237; UCC § 2-309",
88
+ "consumer_standard": "FTC: adequate notice period required for service termination",
89
+ "risk_level": "CRITICAL",
90
+ },
91
+ "Limitation of liability": {
92
+ "risky_pattern": "Company not liable for any damages / complete disclaimer",
93
+ "safe_alternative": (
94
+ "Neither party shall be liable for indirect, incidental, special, or consequential "
95
+ "damages, EXCEPT in cases of: (a) gross negligence or willful misconduct, "
96
+ "(b) breach of confidentiality, (c) data breach involving personal information, or "
97
+ "(d) intellectual property infringement. Direct damages are limited to fees paid "
98
+ "in the prior twelve (12) months. Nothing in this Agreement limits liability for "
99
+ "death or personal injury caused by negligence."
100
+ ),
101
+ "legal_basis": "UCC § 2-719(3); EU Directive 93/13/EEC Annex (a)",
102
+ "consumer_standard": "Cannot exclude liability for death/personal injury (EU/UK law)",
103
+ "risk_level": "CRITICAL",
104
+ },
105
+ "Unilateral termination": {
106
+ "risky_pattern": "Company can terminate account at any time without reason",
107
+ "safe_alternative": (
108
+ "The Provider may suspend or terminate the User's account for: (a) material breach "
109
+ "of these Terms, (b) non-payment after ten (10) days' notice, (c) illegal activity, "
110
+ "or (d) extended inactivity exceeding twelve (12) months. The Provider shall provide "
111
+ "at least thirty (30) days' written notice before termination, except in cases of "
112
+ "illegal activity. Upon termination, the User shall have thirty (30) days to export "
113
+ "their data."
114
+ ),
115
+ "legal_basis": "EU Directive 2019/770 (Digital Content); CFPB guidance",
116
+ "consumer_standard": "Right to export data upon termination; adequate notice period",
117
+ "risk_level": "CRITICAL",
118
+ },
119
+ "Liquidated Damages": {
120
+ "risky_pattern": "Pre-determined damages far exceeding actual harm",
121
+ "safe_alternative": (
122
+ "In the event of breach, the non-breaching party shall be entitled to liquidated "
123
+ "damages in the amount of [specific reasonable amount], which the parties agree "
124
+ "represents a reasonable estimate of anticipated harm. This liquidated damages "
125
+ "provision shall not apply if actual damages are readily ascertainable, in which "
126
+ "case the non-breaching party may recover actual damages proven."
127
+ ),
128
+ "legal_basis": "Restatement (Second) of Contracts § 356; UCC § 2-718",
129
+ "consumer_standard": "Liquidated damages must be reasonable estimate, not penalty",
130
+ "risk_level": "CRITICAL",
131
+ },
132
+
133
+ # ── HIGH Risk Clauses ──────────────────────────────────────────
134
+ "Unilateral change": {
135
+ "risky_pattern": "We may modify terms at any time without notice",
136
+ "safe_alternative": (
137
+ "Material changes to these Terms require thirty (30) days' advance written notice "
138
+ "to the User via email and in-app notification. The User has the right to terminate "
139
+ "without penalty within the notice period if they do not accept the changes. "
140
+ "Non-material changes (e.g., formatting, clarifications) may be made without notice."
141
+ ),
142
+ "legal_basis": "EU Directive 93/13/EEC Art. 3; Restatement (Second) § 89",
143
+ "consumer_standard": "FTC: material changes require notice and right to reject",
144
+ "risk_level": "HIGH",
145
+ },
146
+ "Content removal": {
147
+ "risky_pattern": "Company can delete content at sole discretion without notice",
148
+ "safe_alternative": (
149
+ "Content may be removed only for violation of these Terms of Service, applicable law, "
150
+ "or valid legal process. The Provider shall provide prior notice specifying the reason "
151
+ "for removal (except where legally prohibited). The User has the right to appeal "
152
+ "within fourteen (14) days. Removed content shall be preserved for thirty (30) days "
153
+            "to allow for appeal resolution."
+        ),
+        "legal_basis": "EU Digital Services Act Art. 17; First Amendment considerations",
+        "consumer_standard": "Due process: notice, reason, and right to appeal",
+        "risk_level": "HIGH",
+    },
+    "Non-Compete": {
+        "risky_pattern": "Broad non-compete with no time/geography limits",
+        "safe_alternative": (
+            "During the term of this Agreement and for a period of [6-12] months thereafter, "
+            "the Receiving Party shall not directly compete with the Disclosing Party in "
+            "[specific market/geography]. This restriction applies only to [specific business "
+            "activities] and does not prevent general employment in the industry. The Disclosing "
+            "Party shall provide [garden leave pay / consideration] during the restricted period."
+        ),
+        "legal_basis": "Restatement (Second) of Contracts § 188; FTC Non-Compete Rule (2024)",
+        "consumer_standard": "Reasonable scope, duration, geography; adequate consideration",
+        "risk_level": "HIGH",
+    },
+    "Exclusivity": {
+        "risky_pattern": "Exclusive dealing with no time limit or exit clause",
+        "safe_alternative": (
+            "The exclusivity arrangement shall apply for an initial term of [12-24] months, "
+            "after which either party may convert to non-exclusive upon sixty (60) days' notice. "
+            "Exclusivity is limited to [specific product/service category] and [specific "
+            "geographic area]. Performance benchmarks shall be reviewed quarterly; failure to "
+            "meet agreed minimums allows termination of exclusivity."
+        ),
+        "legal_basis": "Sherman Act § 1; EU Competition Law Art. 101 TFEU",
+        "consumer_standard": "Time-limited, scope-limited, with performance exit clause",
+        "risk_level": "HIGH",
+    },
+    "Anti-Assignment": {
+        "risky_pattern": "Complete prohibition on assignment without consent",
+        "safe_alternative": (
+            "Neither party may assign this Agreement without the prior written consent of the "
+            "other party, which shall not be unreasonably withheld, conditioned, or delayed. "
+            "Notwithstanding the foregoing, either party may assign this Agreement without "
+            "consent in connection with a merger, acquisition, or sale of substantially all "
+            "of its assets, provided the assignee assumes all obligations hereunder."
+        ),
+        "legal_basis": "UCC § 2-210; Restatement (Second) of Contracts § 317",
+        "consumer_standard": "Consent not to be unreasonably withheld; M&A carve-out",
+        "risk_level": "HIGH",
+    },
+
+    # ── MEDIUM Risk Clauses ────────────────────────────────────────
+    "Jurisdiction": {
+        "risky_pattern": "Exclusive jurisdiction in distant/foreign state",
+        "safe_alternative": (
+            "The Consumer may bring claims in their jurisdiction of residence or the Provider's "
+            "principal place of business. Small claims actions may be brought in any court of "
+            "competent jurisdiction. For commercial contracts: disputes shall be resolved in "
+            "[mutually agreed location] or the defendant's principal place of business."
+        ),
+        "legal_basis": "EU Regulation 1215/2012 (Brussels I); CJEU C-585/08",
+        "consumer_standard": "Consumer may sue in home jurisdiction (EU Directive 93/13)",
+        "risk_level": "MEDIUM",
+    },
+    "Choice of law": {
+        "risky_pattern": "Governed by laws of a jurisdiction that disadvantages consumer",
+        "safe_alternative": (
+            "This Agreement shall be governed by the laws of [State/Country]. Notwithstanding "
+            "the foregoing, nothing in this choice of law provision shall deprive the Consumer "
+            "of the protection afforded by mandatory provisions of the law of the Consumer's "
+            "habitual residence."
+        ),
+        "legal_basis": "EU Regulation 593/2008 (Rome I) Art. 6; UCC § 1-301",
+        "consumer_standard": "Cannot override mandatory consumer protection of home jurisdiction",
+        "risk_level": "MEDIUM",
+    },
+    "Contract by using": {
+        "risky_pattern": "Bound to contract by merely using the service (browsewrap)",
+        "safe_alternative": (
+            "By creating an account, the User acknowledges they have read, understood, and agree "
+            "to be bound by these Terms. The User must affirmatively accept these Terms via "
+            "checkbox or click-through before account creation. Continued use after material "
+            "changes requires re-acceptance."
+        ),
+        "legal_basis": "Specht v. Netscape, 306 F.3d 17 (2d Cir. 2002)",
+        "consumer_standard": "Clickwrap > browsewrap; affirmative acceptance required",
+        "risk_level": "MEDIUM",
+    },
+
+    # ── Additional Common Clauses ──────────────────────────────────
+    "Auto-Renewal": {
+        "risky_pattern": "Auto-renews silently without notice",
+        "safe_alternative": (
+            "This Agreement shall automatically renew for successive [term] periods unless "
+            "either party provides written notice of non-renewal at least thirty (30) days "
+            "before the end of the then-current term. The Provider shall send a reminder "
+            "notice thirty (30) to sixty (60) days before renewal. The Consumer may cancel "
+            "within fifteen (15) days of renewal for a pro-rated refund."
+        ),
+        "legal_basis": "California Auto-Renewal Law (ARL) Bus. & Prof. Code § 17600; FTC Negative Option Rule",
+        "consumer_standard": "Reminder notice required; easy cancellation; pro-rated refund",
+        "risk_level": "HIGH",
+    },
+    "Indemnification": {
+        "risky_pattern": "User indemnifies company for all claims without limit",
+        "safe_alternative": (
+            "Each party shall indemnify, defend, and hold harmless the other party from "
+            "third-party claims arising from: (a) the indemnifying party's breach of this "
+            "Agreement, (b) the indemnifying party's negligence or willful misconduct, or "
+            "(c) the indemnifying party's violation of applicable law. The User's indemnification "
+            "obligation is limited to claims arising from the User's own negligence or "
+            "intentional acts. The maximum indemnification obligation shall not exceed [amount]."
+        ),
+        "legal_basis": "Restatement (Second) of Contracts § 345; UCC § 2-607",
+        "consumer_standard": "Mutual indemnification; limited to own acts; capped",
+        "risk_level": "HIGH",
+    },
+    "Confidentiality": {
+        "risky_pattern": "Overly broad confidentiality with no exceptions or time limit",
+        "safe_alternative": (
+            "Each party agrees to maintain the confidentiality of the other's Confidential "
+            "Information for a period of [3-5] years from disclosure. Confidential Information "
+            "excludes: (a) publicly available information, (b) independently developed "
+            "information, (c) information received from a third party without restriction, "
+            "(d) information required to be disclosed by law or court order (with prompt notice "
+            "to the disclosing party)."
+        ),
+        "legal_basis": "Restatement (Third) of Unfair Competition § 39-45",
+        "consumer_standard": "Time-limited; standard exceptions; required disclosure carve-out",
+        "risk_level": "MEDIUM",
+    },
+}
+
+# Mapping from CUAD/unfair labels to our template keys
+_LABEL_TO_TEMPLATE = {
+    "Uncapped Liability": "Uncapped Liability",
+    "Arbitration": "Arbitration",
+    "IP Ownership Assignment": "IP Ownership Assignment",
+    "Termination for Convenience": "Termination for Convenience",
+    "Limitation of liability": "Limitation of liability",
+    "Unilateral termination": "Unilateral termination",
+    "Liquidated Damages": "Liquidated Damages",
+    "Unilateral change": "Unilateral change",
+    "Content removal": "Content removal",
+    "Non-Compete": "Non-Compete",
+    "Exclusivity": "Exclusivity",
+    "Anti-Assignment": "Anti-Assignment",
+    "Jurisdiction": "Jurisdiction",
+    "Choice of law": "Choice of law",
+    "Contract by using": "Contract by using",
+    "Cap on Liability": "Limitation of liability",  # Similar enough
+    "No-Solicit of Customers": "Non-Compete",  # Use non-compete template
+    "No-Solicit of Employees": "Non-Compete",
+    "Non-Disparagement": "Confidentiality",  # Similar restrictive clause
+}
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# TIER 2: RAG RETRIEVAL (find fairer precedent clauses)
+# ═══════════════════════════════════════════════════════════════════════
+
+def _find_similar_templates(clause_label, clause_text):
+    """
+    Find the most relevant safe alternative template(s) for a given clause.
+    Returns list of matching templates.
+    """
+    matches = []
+
+    # Direct label match
+    template_key = _LABEL_TO_TEMPLATE.get(clause_label)
+    if template_key and template_key in SAFE_ALTERNATIVES:
+        matches.append((template_key, SAFE_ALTERNATIVES[template_key], 1.0))
+
+    # Also do keyword matching for clauses that might not have exact label matches
+    clause_lower = clause_text.lower()
+    keyword_map = {
+        "Uncapped Liability": ["unlimited liability", "uncapped", "no limit on liability"],
+        "Arbitration": ["arbitration", "arbitrate", "waive right to court", "class action waiver"],
+        "Termination for Convenience": ["terminate at any time", "terminate without cause", "terminate without notice"],
+        "Limitation of liability": ["not liable", "limitation of liability", "in no event", "disclaim"],
+        "Unilateral change": ["modify at any time", "sole discretion", "change terms", "without notice"],
+        "Content removal": ["remove content", "delete content", "remove at sole discretion"],
+        "Auto-Renewal": ["auto-renew", "automatically renew", "automatic renewal"],
+        "Indemnification": ["indemnif", "hold harmless"],
+    }
+
+    for key, keywords in keyword_map.items():
+        if key in SAFE_ALTERNATIVES:
+            for kw in keywords:
+                if kw in clause_lower:
+                    # Avoid duplicates
+                    if not any(m[0] == key for m in matches):
+                        matches.append((key, SAFE_ALTERNATIVES[key], 0.7))
+                    break
+
+    return matches
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# TIER 3: LLM REFINEMENT
+# ═══════════════════════════════════════════════════════════════════════
+
+_LLM_MODEL = "Qwen/Qwen2.5-7B-Instruct"
+
+def _refine_with_llm(original_clause, template, clause_label):
+    """
+    Use LLM to adapt the template to the specific clause context.
+    The LLM refines — it does NOT generate from scratch (anti-hallucination).
+    """
+    if not _HAS_INFERENCE:
+        return None
+
+    try:
+        token = os.environ.get("HF_TOKEN", "")
+        client = InferenceClient(
+            provider="hf-inference",
+            api_key=token if token else None,
+        )
+
+        prompt = f"""You are a legal contract redlining assistant. Your task is to adapt a safe clause template to fit the specific context of an original risky clause.
+
+RULES:
+1. You MUST use the provided template as your base — do NOT generate clauses from scratch.
+2. Preserve the legal protections in the template.
+3. Adapt specific details (parties, amounts, timeframes) from the original clause.
+4. Keep the same legal standard cited in the template.
+5. Output ONLY the refined clause text, nothing else.
+6. The refined clause should be immediately usable in a contract.
+
+ORIGINAL RISKY CLAUSE:
+{original_clause[:500]}
+
+CLAUSE TYPE: {clause_label}
+
+SAFE TEMPLATE:
+{template['safe_alternative']}
+
+LEGAL BASIS: {template['legal_basis']}
+
+Write the refined safer clause (adapt the template to this specific contract's context):"""
+
+        response = client.chat_completion(
+            model=_LLM_MODEL,
+            messages=[
+                {"role": "system", "content": "You are a legal contract redlining expert. Output ONLY the refined clause text."},
+                {"role": "user", "content": prompt},
+            ],
+            max_tokens=512,
+            temperature=0.2,
+        )
+        refined = response.choices[0].message.content.strip()
+
+        # Sanity check: refined should be substantial
+        if len(refined) < 50:
+            return None
+        return refined
+
+    except Exception as e:
+        print(f"[ClauseGuard Redline] LLM refinement error: {e}")
+        return None
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# PUBLIC API
+# ═══════════════════════════════════════════════════════════════════════
+
+def generate_redlines(analysis_result, use_llm=True):
+    """
+    Generate redline suggestions for all flagged clauses in the analysis.
+
+    Returns list of redline suggestions:
+        [{
+            "original_text": str,
+            "clause_label": str,
+            "risk_level": str,
+            "safe_alternative": str,
+            "template_alternative": str,
+            "legal_basis": str,
+            "consumer_standard": str,
+            "tier": "template" | "llm_refined",
+        }]
+    """
+    if analysis_result is None:
+        return []
+
+    clauses = analysis_result.get("clauses", [])
+    if not clauses:
+        return []
+
+    redlines = []
+    seen_labels = set()  # Deduplicate by label
+
+    # Sort by risk level: CRITICAL first
+    risk_order = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3}
+    sorted_clauses = sorted(clauses, key=lambda c: risk_order.get(c.get("risk", "LOW"), 3))
+
+    for clause in sorted_clauses:
+        label = clause.get("label", "")
+        risk = clause.get("risk", "LOW")
+        text = clause.get("text", "")
+
+        # Skip LOW risk and already-seen labels
+        if risk == "LOW" or label in seen_labels:
+            continue
+        seen_labels.add(label)
+
+        # Find matching templates (Tier 1 + Tier 2)
+        matches = _find_similar_templates(label, text)
+        if not matches:
+            continue
+
+        best_key, best_template, score = matches[0]
+
+        # Tier 3: Try LLM refinement if enabled
+        refined_text = None
+        tier = "template"
+        if use_llm and risk in ("CRITICAL", "HIGH"):
+            refined_text = _refine_with_llm(text, best_template, label)
+            if refined_text:
+                tier = "llm_refined"
+
+        redlines.append({
+            "original_text": text[:500],
+            "clause_label": label,
+            "risk_level": risk,
+            "safe_alternative": refined_text or best_template["safe_alternative"],
+            "template_alternative": best_template["safe_alternative"],
+            "legal_basis": best_template["legal_basis"],
+            "consumer_standard": best_template["consumer_standard"],
+            "tier": tier,
+        })
+
+    return redlines
+
+
+def render_redlines_html(redlines):
+    """Render redline suggestions as HTML for Gradio."""
+    if not redlines:
+        return '''<div style="padding:24px;text-align:center;color:#6b7280;font-family:system-ui,sans-serif;">
+            <p style="font-size:16px;">📝 No redline suggestions available.</p>
+            <p style="font-size:13px;">Analyze a contract first — redlining suggestions will appear for risky clauses.</p>
+        </div>'''
+
+    risk_styles = {
+        "CRITICAL": ("#dc2626", "#fef2f2", "⚠️"),
+        "HIGH": ("#ea580c", "#fff7ed", "⚡"),
+        "MEDIUM": ("#ca8a04", "#fefce8", "📋"),
+        "LOW": ("#16a34a", "#f0fdf4", "✓"),
+    }
+
+    html = '<div style="font-family:system-ui,sans-serif;">'
+
+    # Summary header
+    crit = sum(1 for r in redlines if r["risk_level"] == "CRITICAL")
+    high = sum(1 for r in redlines if r["risk_level"] == "HIGH")
+    med = sum(1 for r in redlines if r["risk_level"] == "MEDIUM")
+    llm_count = sum(1 for r in redlines if r["tier"] == "llm_refined")
+
+    html += f'''
+    <div style="padding:16px;background:linear-gradient(135deg,#eff6ff,#f0fdf4);border-radius:12px;margin-bottom:16px;border:1px solid #e5e7eb;">
+        <div style="display:flex;align-items:center;gap:8px;margin-bottom:8px;">
+            <span style="font-size:24px;">✏️</span>
+            <h2 style="margin:0;font-size:18px;color:#1f2937;">Clause Redlining Suggestions</h2>
+        </div>
+        <p style="font-size:13px;color:#6b7280;margin:0;">
+            {len(redlines)} suggestions: {crit} Critical · {high} High · {med} Medium
+            {f" · {llm_count} LLM-refined" if llm_count else ""}
+        </p>
+    </div>
+    '''
+
+    for i, redline in enumerate(redlines):
+        border_color, bg_color, icon = risk_styles.get(
+            redline["risk_level"], ("#6b7280", "#f9fafb", "•")
+        )
+        tier_badge = (
+            '<span style="font-size:10px;background:#eff6ff;color:#3b82f6;padding:2px 8px;border-radius:4px;">🤖 LLM Refined</span>'
+            if redline["tier"] == "llm_refined"
+            else '<span style="font-size:10px;background:#f0fdf4;color:#16a34a;padding:2px 8px;border-radius:4px;">📋 Template</span>'
+        )
+
+        # Escape "&" first so existing entities are not double-escaped
+        original_preview = redline["original_text"][:200].replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
+        safe_text = redline["safe_alternative"].replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
+
+        html += f'''
+        <div style="border:1px solid #e5e7eb;border-left:4px solid {border_color};border-radius:8px;margin-bottom:12px;overflow:hidden;">
+            <!-- Header -->
+            <div style="padding:12px 16px;background:{bg_color};border-bottom:1px solid #e5e7eb;">
+                <div style="display:flex;align-items:center;justify-content:space-between;">
+                    <div style="display:flex;align-items:center;gap:8px;">
+                        <span style="font-size:16px;">{icon}</span>
+                        <span style="font-size:14px;font-weight:600;color:{border_color};">{redline["clause_label"]}</span>
+                        <span style="font-size:11px;color:{border_color};text-transform:uppercase;font-weight:600;">{redline["risk_level"]}</span>
+                    </div>
+                    {tier_badge}
+                </div>
+            </div>
+
+            <!-- Body -->
+            <div style="padding:16px;">
+                <!-- Original (risky) -->
+                <div style="margin-bottom:12px;">
+                    <div style="font-size:11px;font-weight:600;color:#991b1b;text-transform:uppercase;margin-bottom:4px;">❌ Original (Risky)</div>
+                    <div style="background:#fef2f2;border:1px solid #fecaca;border-radius:6px;padding:10px;font-size:12px;color:#991b1b;line-height:1.6;">
+                        <del>{original_preview}{"..." if len(redline["original_text"]) > 200 else ""}</del>
+                    </div>
+                </div>
+
+                <!-- Suggested (safe) -->
+                <div style="margin-bottom:12px;">
+                    <div style="font-size:11px;font-weight:600;color:#166534;text-transform:uppercase;margin-bottom:4px;">✅ Suggested Alternative</div>
+                    <div style="background:#f0fdf4;border:1px solid #bbf7d0;border-radius:6px;padding:10px;font-size:12px;color:#166534;line-height:1.6;">
+                        {safe_text}
+                    </div>
+                </div>
+
+                <!-- Legal basis -->
+                <div style="display:flex;gap:12px;flex-wrap:wrap;">
+                    <div style="flex:1;min-width:200px;">
+                        <div style="font-size:10px;font-weight:600;color:#6b7280;text-transform:uppercase;margin-bottom:2px;">📚 Legal Basis</div>
+                        <div style="font-size:11px;color:#4b5563;">{redline["legal_basis"]}</div>
+                    </div>
+                    <div style="flex:1;min-width:200px;">
+                        <div style="font-size:10px;font-weight:600;color:#6b7280;text-transform:uppercase;margin-bottom:2px;">🛡️ Consumer Standard</div>
+                        <div style="font-size:11px;color:#4b5563;">{redline["consumer_standard"]}</div>
+                    </div>
+                </div>
+            </div>
+        </div>
+        '''
+
+    # Disclaimer
+    html += '''
+    <div style="margin-top:16px;padding:12px;background:#fefce8;border:1px solid #fde68a;border-radius:8px;">
+        <p style="font-size:11px;color:#92400e;margin:0;line-height:1.5;">
+            <strong>⚠️ Disclaimer:</strong> These are AI-generated suggestions based on legal templates and consumer protection standards.
+            They are NOT legal advice. The suggested alternatives are starting points that should be reviewed and customized by a
+            qualified attorney before use in any contract. Legal requirements vary by jurisdiction.
+        </p>
+    </div>
+    '''
+
+    html += '</div>'
+    return html
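
The tiered flow above (direct label lookup → keyword retrieval → optional LLM refinement) can be exercised without the Space or the LLM. A minimal self-contained sketch of the template tier — the dict and clause data below are illustrative, not the module's real `SAFE_ALTERNATIVES`:

```python
# Sketch of generate_redlines' template tier: sort by risk, skip LOW,
# dedupe by label, and attach the matching safe template.
SAFE_ALTS = {  # hypothetical, trimmed stand-in for SAFE_ALTERNATIVES
    "Arbitration": {
        "safe_alternative": "Arbitration shall be optional for the Consumer...",
        "legal_basis": "FAA § 2 (illustrative)",
    },
}

def sketch_redlines(clauses):
    redlines, seen = [], set()
    order = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3}
    for c in sorted(clauses, key=lambda c: order.get(c.get("risk", "LOW"), 3)):
        label, risk = c.get("label", ""), c.get("risk", "LOW")
        if risk == "LOW" or label in seen or label not in SAFE_ALTS:
            continue  # LOW-risk and duplicate labels produce no suggestion
        seen.add(label)
        redlines.append({
            "clause_label": label,
            "risk_level": risk,
            "safe_alternative": SAFE_ALTS[label]["safe_alternative"],
            "tier": "template",  # Tier 3 would flip this to "llm_refined"
        })
    return redlines

out = sketch_redlines([
    {"label": "Arbitration", "risk": "HIGH", "text": "All disputes go to arbitration."},
    {"label": "Arbitration", "risk": "HIGH", "text": "duplicate — deduped by label"},
    {"label": "Jurisdiction", "risk": "LOW", "text": "skipped — LOW risk"},
])
```

With `use_llm=False` (or when the Inference API is unavailable) the real module degrades to exactly this behavior, which is why every redline carries a `template_alternative` fallback.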
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 gradio>=5.23.0
-transformers>=5.6.1
+transformers>=4.45.0
 torch>=2.5.0
 numpy>=2.0.0
 pdfplumber>=0.11.0
@@ -7,3 +7,5 @@ python-docx>=1.1.0
 peft>=0.15.0
 accelerate>=1.2.0
 sentence-transformers>=3.0.0
+python-doctr[torch]>=0.9.0
+huggingface_hub>=0.25.0
web/.env.example CHANGED
@@ -18,3 +18,10 @@ RESEND_API_KEY=re_...
 # App
 NEXT_PUBLIC_SITE_URL=http://localhost:3000
 CLAUSEGUARD_API_URL=https://gaurv007-clauseguard-api.hf.space
+
+# HF Inference API (for chatbot + redlining LLM)
+HF_TOKEN=hf_...
+
+# Optional: SaulLM for explain endpoint
+SAULLM_ENDPOINT=
+HF_API_TOKEN=
web/app/api/analyze/route.ts CHANGED
@@ -1,4 +1,5 @@
 import { NextRequest, NextResponse } from "next/server";
+import { createClient } from "@/lib/supabase/server";
 
 const API_URL = process.env.CLAUSEGUARD_API_URL || "https://gaurv007-clauseguard-api.hf.space";
 
@@ -14,10 +15,19 @@ export async function POST(req: NextRequest) {
     );
   }
 
-  // Forward to backend API v2.0 (full text, clauses split server-side)
+  // Forward auth token to backend
+  const headers: Record<string, string> = { "Content-Type": "application/json" };
+  try {
+    const supabase = await createClient();
+    const { data: { session } } = await supabase.auth.getSession();
+    if (session?.access_token) {
+      headers["Authorization"] = `Bearer ${session.access_token}`;
+    }
+  } catch {}
+
   const response = await fetch(`${API_URL}/api/analyze`, {
     method: "POST",
-    headers: { "Content-Type": "application/json" },
+    headers,
     body: JSON.stringify({ text, source_url }),
   });
 
web/app/api/chat/route.ts ADDED
@@ -0,0 +1,37 @@
+import { NextRequest, NextResponse } from "next/server";
+
+const API_URL = process.env.CLAUSEGUARD_API_URL || "https://gaurv007-clauseguard-api.hf.space";
+
+export async function POST(req: NextRequest) {
+  try {
+    const body = await req.json();
+    const { message, session_id, history } = body;
+
+    if (!message || !session_id) {
+      return NextResponse.json(
+        { error: "message and session_id are required" },
+        { status: 400 }
+      );
+    }
+
+    const response = await fetch(`${API_URL}/api/chat`, {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({ message, session_id, history: history || [] }),
+    });
+
+    if (!response.ok) {
+      const err = await response.text().catch(() => "");
+      throw new Error(err || `Backend error: ${response.status}`);
+    }
+
+    const result = await response.json();
+    return NextResponse.json(result);
+  } catch (error: any) {
+    console.error("Chat error:", error.message);
+    return NextResponse.json(
+      { error: error.message || "Chat failed. Try again." },
+      { status: 500 }
+    );
+  }
+}
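
The chat proxy above accepts a JSON body of `{message, session_id, history}` and returns 400 when `message` or `session_id` is missing; the frontend truncates `history` to the last six messages before sending. A small sketch of building and validating that payload from any client (the `build_chat_payload` helper is hypothetical, not part of the repo):

```python
import json

def build_chat_payload(message, session_id, history=None):
    """Build the JSON body the /api/chat proxy expects.

    Mirrors the frontend behavior: history defaults to [] and is
    truncated to the last 6 messages. Illustrative helper only.
    """
    if not message or not session_id:
        # The proxy would answer 400 in this case; fail fast client-side.
        raise ValueError("message and session_id are required")
    history = (history or [])[-6:]
    return json.dumps({"message": message, "session_id": session_id, "history": history})

payload = build_chat_payload(
    "What does the arbitration clause mean?",
    "sess-123",
    [{"role": "user", "content": f"m{i}"} for i in range(10)],  # 10 msgs -> last 6 kept
)
```

The same shape (plus `use_llm` instead of `message`/`history`) applies to the `/api/redline` route that follows.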
web/app/api/redline/route.ts ADDED
@@ -0,0 +1,37 @@
+import { NextRequest, NextResponse } from "next/server";
+
+const API_URL = process.env.CLAUSEGUARD_API_URL || "https://gaurv007-clauseguard-api.hf.space";
+
+export async function POST(req: NextRequest) {
+  try {
+    const body = await req.json();
+    const { session_id, text, use_llm } = body;
+
+    if (!session_id && !text) {
+      return NextResponse.json(
+        { error: "Provide session_id or text" },
+        { status: 400 }
+      );
+    }
+
+    const response = await fetch(`${API_URL}/api/redline`, {
+      method: "POST",
+      headers: { "Content-Type": "application/json" },
+      body: JSON.stringify({ session_id, text, use_llm: use_llm ?? true }),
+    });
+
+    if (!response.ok) {
+      const err = await response.text().catch(() => "");
+      throw new Error(err || `Backend error: ${response.status}`);
+    }
+
+    const result = await response.json();
+    return NextResponse.json(result);
+  } catch (error: any) {
+    console.error("Redline error:", error.message);
+    return NextResponse.json(
+      { error: error.message || "Redlining failed" },
+      { status: 500 }
+    );
+  }
+}
web/app/dashboard-pages/analyze/page.tsx CHANGED
@@ -9,7 +9,8 @@ import {
   AlertTriangle, Tag, BookOpen, ClipboardList, DollarSign,
   Calendar, Building, MapPin, Hash, Bot, FileSearch, Percent, Clock,
   User, BookMarked, ShieldX, HelpCircle, Cpu, PenTool, Zap,
-  ShieldOff, CircleSlash, MessageSquareWarning, Construction
+  ShieldOff, CircleSlash, MessageSquareWarning, Construction,
+  MessageSquare, Send, Loader2
 } from "lucide-react";
 
 interface Cat { name: string; severity: string; description?: string; confidence?: number; }
@@ -19,6 +20,17 @@ interface Contradiction { type: string; explanation: string; severity: string; c
 interface Obligation { type: string; party: string; description: string; deadline: string; priority?: number; }
 interface ComplianceCheck { requirement: string; description: string; severity: string; status: string; matched_keywords: string[]; context?: string[]; }
 interface ComplianceReg { description: string; compliance_rate: number; checks: ComplianceCheck[]; overall_status: string; negated_count?: number; ambiguous_count?: number; }
+interface Redline {
+  original_text: string;
+  clause_label: string;
+  risk_level: string;
+  safe_alternative: string;
+  template_alternative?: string;
+  legal_basis: string;
+  consumer_standard: string;
+  tier: string;
+}
+interface ChatMessage { role: "user" | "assistant"; content: string; }
 interface AnalysisResult {
   risk_score: number;
   grade: string;
@@ -29,8 +41,10 @@ interface AnalysisResult {
   contradictions: Contradiction[];
   obligations: Obligation[];
   compliance: Record<string, ComplianceReg>;
+  redlines: Redline[];
   model: string;
   latency_ms: number;
+  session_id?: string;
 }
 
 const SEV_CONFIG: Record<string, { icon: any; label: string; text: string; bg: string; border: string; ring: string }> = {
@@ -169,6 +183,9 @@ export default function AnalyzePage() {
   const [scanLimit, setScanLimit] = useState(10);
   const [canUpload, setCanUpload] = useState(false);
   const [showUpgrade, setShowUpgrade] = useState(false);
+  const [chatMessages, setChatMessages] = useState<ChatMessage[]>([]);
+  const [chatInput, setChatInput] = useState("");
+  const [chatLoading, setChatLoading] = useState(false);
   const fileInputRef = useRef<HTMLInputElement>(null);
 
   // Fetch user profile from DB on mount — no hardcoded emails or plans
@@ -237,6 +254,31 @@ export default function AnalyzePage() {
     setCopied(true); setTimeout(() => setCopied(false), 2000);
   }
 
+  async function handleChat() {
+    if (!chatInput.trim() || !results?.session_id) return;
+    const userMsg: ChatMessage = { role: "user", content: chatInput.trim() };
+    setChatMessages(prev => [...prev, userMsg]);
+    setChatInput("");
+    setChatLoading(true);
+    try {
+      const res = await fetch("/api/chat", {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({
+          message: userMsg.content,
+          session_id: results.session_id,
+          history: chatMessages.slice(-6),
+        }),
+      });
+      if (!res.ok) throw new Error((await res.json()).error || "Chat failed");
+      const data = await res.json();
+      setChatMessages(prev => [...prev, { role: "assistant", content: data.response }]);
+    } catch (e: any) {
+      setChatMessages(prev => [...prev, { role: "assistant", content: `⚠️ ${e.message}` }]);
    }
+    setChatLoading(false);
+  }
+
   const flagged = results?.results.filter(r => r.categories.length > 0) || [];
   const filtered = filter === "all" ? flagged : flagged.filter(r => r.categories.some(c => c.severity === filter));
   const sevCounts = { CRITICAL: 0, HIGH: 0, MEDIUM: 0, LOW: 0 };
@@ -260,6 +302,8 @@ export default function AnalyzePage() {
     { key: "contradictions", label: "Issues", icon: AlertTriangle, count: results?.contradictions.length || 0 },
     { key: "obligations", label: "Obligations", icon: ClipboardList, count: results?.obligations.length || 0 },
     { key: "compliance", label: "Compliance", icon: ShieldCheck, count: Object.keys(results?.compliance || {}).length },
+    { key: "redlining", label: "Redlining", icon: PenTool, count: results?.redlines?.length || 0 },
+    { key: "chat", label: "Q&A", icon: MessageSquare, count: chatMessages.length },
   ];
 
   return (
@@ -668,6 +712,139 @@ export default function AnalyzePage() {
             })}
           </div>
         )}
+
+        {/* Redlining */}
+        {activeTab === "redlining" && (
+          <div className="space-y-3">
+            {(!results.redlines || results.redlines.length === 0) ? (
+              <div className="border border-dashed border-zinc-200 rounded-xl p-8 sm:p-10 text-center bg-white">
+                <PenTool className="w-8 h-8 text-zinc-300 mx-auto mb-2" />
+                <p className="text-sm text-zinc-500">No redlining suggestions for this contract.</p>
+              </div>
+            ) : (
+              <>
+                <div className="bg-gradient-to-r from-blue-50 to-emerald-50 rounded-xl p-4 border border-zinc-200 mb-2">
+                  <div className="flex items-center gap-2 mb-1">
+                    <PenTool className="w-4 h-4 text-zinc-600" />
+                    <span className="text-sm font-semibold text-zinc-800">Clause Redlining Suggestions</span>
+                  </div>
+                  <p className="text-xs text-zinc-500">
+                    {results.redlines.length} suggestions · {results.redlines.filter(r => r.tier === "llm_refined").length} LLM-refined
+                  </p>
+                </div>
+                {results.redlines.map((rl, i) => {
+                  const isHigh = rl.risk_level === "CRITICAL" || rl.risk_level === "HIGH";
+                  const conf = SEV_CONFIG[rl.risk_level] || SEV_CONFIG.MEDIUM;
+                  return (
+                    <div key={i} className={`bg-white border rounded-xl overflow-hidden ${conf.border}`}>
+                      <div className={`px-4 py-3 ${conf.bg} border-b ${conf.border} flex items-center justify-between`}>
+                        <div className="flex items-center gap-2">
+                          <conf.icon className={`w-4 h-4 ${conf.text}`} />
+                          <span className={`text-sm font-semibold ${conf.text}`}>{rl.clause_label}</span>
+                          <span className={`text-[10px] uppercase font-bold ${conf.text}`}>{rl.risk_level}</span>
+                        </div>
+                        <span className={`text-[10px] px-2 py-0.5 rounded border ${
+                          rl.tier === "llm_refined"
+                            ? "bg-indigo-50 text-indigo-600 border-indigo-200"
+                            : "bg-emerald-50 text-emerald-600 border-emerald-200"
+                        }`}>
+                          {rl.tier === "llm_refined" ? "🤖 LLM Refined" : "📋 Template"}
+                        </span>
+                      </div>
+                      <div className="p-4 space-y-3">
+                        <div>
+                          <p className="text-[10px] font-semibold text-red-600 uppercase mb-1">❌ Original (Risky)</p>
+                          <div className="bg-red-50 border border-red-100 rounded-lg p-3 text-xs text-red-800 leading-relaxed line-through">
+                            {rl.original_text.slice(0, 200)}{rl.original_text.length > 200 ? "..." : ""}
+                          </div>
+                        </div>
+                        <div>
+                          <p className="text-[10px] font-semibold text-emerald-600 uppercase mb-1">✅ Suggested Alternative</p>
+                          <div className="bg-emerald-50 border border-emerald-100 rounded-lg p-3 text-xs text-emerald-800 leading-relaxed">
+                            {rl.safe_alternative}
+                          </div>
+                        </div>
+                        <div className="flex gap-3 flex-wrap text-[10px] text-zinc-500">
+                          <span>📚 {rl.legal_basis}</span>
+                          <span>🛡️ {rl.consumer_standard}</span>
+                        </div>
+                      </div>
+                    </div>
+                  );
+                })}
+                <div className="bg-amber-50 border border-amber-200 rounded-lg p-3 text-[11px] text-amber-800">
+                  <strong>⚠️ Disclaimer:</strong> These are AI-generated suggestions, NOT legal advice. Consult an attorney before use.
+                </div>
+              </>
+            )}
+          </div>
+        )}
+
+        {/* Chat */}
+        {activeTab === "chat" && (
+          <div className="flex flex-col h-[350px] sm:h-[420px]">
+            {!results.session_id ? (
+              <div className="flex-1 flex items-center justify-center">
+                <div className="text-center">
+                  <MessageSquare className="w-8 h-8 text-zinc-300 mx-auto mb-2" />
+                  <p className="text-sm text-zinc-500">Chat unavailable — session not initialized.</p>
+                  <p className="text-xs text-zinc-400 mt-1">Try analyzing again with the backend running.</p>
793
+ </div>
794
+ ) : (
795
+ <>
796
+ <div className="flex-1 overflow-y-auto space-y-3 pr-1 mb-3">
797
+ {chatMessages.length === 0 && (
798
+ <div className="text-center py-8">
799
+ <MessageSquare className="w-8 h-8 text-zinc-200 mx-auto mb-2" />
800
+ <p className="text-sm text-zinc-400">Ask a question about your contract</p>
801
+ <div className="mt-3 flex flex-wrap justify-center gap-2">
802
+ {["What are the main risks?", "Who are the parties?", "Is there an arbitration clause?", "Summarize key terms"].map(q => (
803
+ <button key={q} onClick={() => { setChatInput(q); }}
804
+ className="text-xs px-3 py-1.5 rounded-full border border-zinc-200 text-zinc-500 hover:bg-zinc-50 transition-colors">
805
+ {q}
806
+ </button>
807
+ ))}
808
+ </div>
809
+ </div>
810
+ )}
811
+ {chatMessages.map((msg, i) => (
812
+ <div key={i} className={`flex ${msg.role === "user" ? "justify-end" : "justify-start"}`}>
813
+ <div className={`max-w-[85%] rounded-xl px-3.5 py-2.5 text-sm leading-relaxed ${
814
+ msg.role === "user"
815
+ ? "bg-zinc-900 text-white"
816
+ : "bg-zinc-100 text-zinc-700 border border-zinc-200"
817
+ }`}>
818
+ {msg.content}
819
+ </div>
820
+ </div>
821
+ ))}
822
+ {chatLoading && (
823
+ <div className="flex justify-start">
824
+ <div className="bg-zinc-100 border border-zinc-200 rounded-xl px-4 py-3">
825
+ <Loader2 className="w-4 h-4 text-zinc-400 animate-spin" />
826
+ </div>
827
+ </div>
828
+ )}
829
+ </div>
830
+ <div className="flex gap-2 border-t border-zinc-100 pt-3">
831
+ <input
832
+ value={chatInput}
833
+ onChange={(e) => setChatInput(e.target.value)}
834
+ onKeyDown={(e) => e.key === "Enter" && !e.shiftKey && handleChat()}
835
+ placeholder="Ask about your contract..."
836
+ className="flex-1 px-3 py-2 border border-zinc-200 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-zinc-900/10"
837
+ disabled={chatLoading}
838
+ />
839
+ <button onClick={handleChat} disabled={chatLoading || !chatInput.trim()}
840
+ className="px-3 py-2 bg-zinc-900 text-white rounded-lg hover:bg-zinc-800 disabled:opacity-40 transition-colors">
841
+ <Send className="w-4 h-4" />
842
+ </button>
843
+ </div>
844
+ </>
845
+ )}
846
+ </div>
847
+ )}
848
  </div>
849
  </div>
850
  ) : (
web/app/page.tsx CHANGED
@@ -3,7 +3,8 @@ import {
  ShieldCheck, ShieldAlert, Scale, Gavel, ScanText, FileCheck,
  TriangleAlert, ArrowRight, Zap, Eye, Download, ChevronRight,
  Sparkles, Lock, Globe, Ban, FileX, Stamp, Layers, Tag, AlertTriangle,
- ClipboardList, Landmark, Building, BookOpen, CheckCircle, Cpu
+ ClipboardList, Landmark, Building, BookOpen, CheckCircle, Cpu,
+ MessageSquare, PenTool, ScanLine
  } from "lucide-react";
 
  const CLAUSES = [
@@ -21,22 +22,26 @@ const CLAUSES = [
  { icon: ClipboardList, name: "Obligations", desc: "Track monetary, compliance, reporting tasks with priority", severity: "medium" },
  { icon: Landmark, name: "Compliance", desc: "GDPR, CCPA, SOX, HIPAA, FINRA with negation detection", severity: "high" },
  { icon: BookOpen, name: "Compare Contracts", desc: "Semantic similarity with sentence embeddings", severity: "low" },
+ { icon: PenTool, name: "Clause Redlining", desc: "AI suggests safer alternatives with legal citations", severity: "critical" },
+ { icon: MessageSquare, name: "Q&A Chatbot", desc: "Ask questions about your contract — RAG-powered answers", severity: "medium" },
+ { icon: ScanLine, name: "OCR for Scanned PDFs", desc: "docTR engine auto-detects and OCRs scanned contracts", severity: "low" },
+ { icon: Cpu, name: "6 AI Models", desc: "Legal-BERT, NER, NLI, Embeddings, OCR, Qwen2.5-7B LLM", severity: "low" },
  ];
 
  const STEPS = [
- { icon: Download, title: "Upload or paste", desc: "Drop a PDF, DOCX, or paste contract text directly." },
- { icon: ScanText, title: "3 AI models analyze", desc: "Legal-BERT classifier + Legal NER + DeBERTa NLI scan your contract." },
- { icon: TriangleAlert, title: "Get precise insights", desc: "Risk score, contradictions, obligations, compliance gaps with source indicators." },
+ { icon: Download, title: "Upload or paste", desc: "Drop a PDF (even scanned!), DOCX, or paste contract text directly." },
+ { icon: ScanText, title: "6 AI models analyze", desc: "Legal-BERT + NER + NLI + OCR + Embeddings + LLM scan your contract." },
+ { icon: TriangleAlert, title: "Get precise insights", desc: "Risk score, redlining, Q&A chatbot, contradictions, obligations, and compliance." },
  ];
 
  const PRICING = [
  {
  name: "Free", price: "0", period: "", highlight: false, cta: "Get started",
- features: ["10 scans per month", "41 clause categories", "Risk scoring", "ML Legal NER", "NLI contradiction detection", "Compliance with negation detection"],
+ features: ["10 scans per month", "41 clause categories", "Risk scoring", "ML Legal NER", "NLI contradiction detection", "Compliance with negation detection", "Clause redlining suggestions", "OCR for scanned PDFs"],
  },
  {
  name: "Pro", price: "999", period: "/mo", highlight: true, cta: "Start free trial",
- features: ["Unlimited scans", "Upload PDF/DOCX files", "Contract comparison", "AI clause explanations", "Scan history", "PDF report export", "Obligation tracker with priority", "Priority support"],
+ features: ["Unlimited scans", "Upload PDF/DOCX files", "Contract comparison", "Q&A Chatbot (RAG)", "AI clause explanations", "LLM-refined redlining", "Scan history", "PDF report export", "Obligation tracker with priority", "Priority support"],
  },
  {
  name: "Team", price: "3,999", period: "/mo", highlight: false, cta: "Talk to us",
@@ -59,14 +64,14 @@ export default function Home() {
  <div className="max-w-2xl">
  <div className="inline-flex items-center gap-2 px-3 py-1 rounded-full border border-zinc-200 text-[13px] text-zinc-500 mb-6">
  <Sparkles className="w-3.5 h-3.5 text-zinc-400" />
- 3 ML models · 41 clause categories · negation-aware compliance
+ 6 AI models · 41 clause categories · RAG chatbot · clause redlining · OCR
  </div>
  <h1 className="text-3xl sm:text-[42px] lg:text-5xl font-semibold tracking-tight leading-[1.1]">
  Know what you are<br className="hidden sm:block" /> agreeing to
  </h1>
  <p className="mt-5 text-base sm:text-[17px] text-zinc-500 leading-relaxed max-w-lg">
- ClauseGuard scans contracts, terms of service, and leases using 3 specialized AI models.
- Get precise clause detection, risk scoring, ML entity extraction, NLI contradiction alerts, and negation-aware compliance checks.
+ ClauseGuard scans contracts using 6 AI models. Get clause detection, risk scoring,
+ safer alternatives, Q&A chatbot, OCR for scanned PDFs, and compliance checks.
  </p>
  <div className="mt-8 flex flex-col sm:flex-row gap-3">
  <Link href="/dashboard-pages/analyze" className="inline-flex items-center justify-center gap-2 bg-zinc-900 text-white px-5 py-2.5 rounded-lg text-sm font-medium hover:bg-zinc-800 transition-colors">
@@ -87,11 +92,11 @@
  <ShieldCheck className="w-4 h-4 text-zinc-400" />
  <p className="text-[13px] font-medium text-zinc-400 uppercase tracking-wider">Detection</p>
  </div>
- <h2 className="text-xl sm:text-2xl font-semibold tracking-tight">14 powerful analysis features</h2>
+ <h2 className="text-xl sm:text-2xl font-semibold tracking-tight">18 powerful analysis features</h2>
  <p className="mt-2 text-zinc-500 text-sm sm:text-[15px] max-w-lg">
- Based on the CUAD taxonomy + CLAUDETTE framework, the same datasets used by EU consumer protection researchers and Stanford NLP.
+ Based on the CUAD taxonomy + CLAUDETTE framework. Now with RAG chatbot, clause redlining, and OCR.
  </p>
- <div className="mt-8 sm:mt-10 grid grid-cols-2 sm:grid-cols-2 lg:grid-cols-4 gap-2 sm:gap-3">
+ <div className="mt-8 sm:mt-10 grid grid-cols-2 sm:grid-cols-3 lg:grid-cols-4 gap-2 sm:gap-3">
  {CLAUSES.map((c) => (
  <div key={c.name} className="group border border-zinc-100 rounded-xl p-3 sm:p-4 hover:border-zinc-200 hover:shadow-sm transition-all cursor-default">
  <div className={`w-7 h-7 sm:w-8 sm:h-8 rounded-lg flex items-center justify-center border ${sevColor[c.severity]}`}>
@@ -135,15 +140,15 @@
  <Cpu className="w-4 h-4 text-zinc-400" />
  <p className="text-[13px] font-medium text-zinc-400 uppercase tracking-wider">Technology</p>
  </div>
- <h2 className="text-xl sm:text-2xl font-semibold tracking-tight">Built on 3 production ML models</h2>
+ <h2 className="text-xl sm:text-2xl font-semibold tracking-tight">Built on 6 production AI models</h2>
  <div className="mt-8 grid sm:grid-cols-2 lg:grid-cols-3 gap-3 sm:gap-4">
  {[
- { name: "Legal-BERT Classifier", icon: Cpu, desc: "LoRA fine-tuned on 41 CUAD categories with sigmoid multi-label classification and per-class thresholds", source: "Mokshith31/legalbert-contract-clause-classification" },
- { name: "Legal-BERT NER", icon: Tag, desc: "ML-based named entity recognition for parties, dates, money, jurisdictions with regex augmentation", source: "matterstack/legal-bert-ner" },
- { name: "DeBERTa-v3 NLI", icon: AlertTriangle, desc: "Cross-encoder model for semantic contradiction detection between clause pairs", source: "cross-encoder/nli-deberta-v3-base" },
- { name: "Compliance Engine", icon: ShieldCheck, desc: "GDPR, CCPA, SOX, HIPAA, FINRA checking with negation detection and context snippets", source: "Negation-aware keyword + semantic" },
- { name: "Obligation Tracker", icon: ClipboardList, desc: "Extracts monetary, compliance, reporting, delivery obligations with priority scoring", source: "Context-filtered regex" },
- { name: "Comparison Engine", icon: Layers, desc: "Semantic similarity via sentence-transformers with SequenceMatcher fallback", source: "all-MiniLM-L6-v2" },
+ { name: "Legal-BERT Classifier", icon: Cpu, desc: "LoRA fine-tuned on 41 CUAD categories with sigmoid multi-label classification", source: "Mokshith31/legalbert-contract-clause-classification" },
+ { name: "Legal-BERT NER", icon: Tag, desc: "Named entity recognition for parties, dates, money, jurisdictions", source: "matterstack/legal-bert-ner" },
+ { name: "DeBERTa-v3 NLI", icon: AlertTriangle, desc: "Semantic contradiction detection between clause pairs", source: "cross-encoder/nli-deberta-v3-base" },
+ { name: "RAG Chatbot", icon: MessageSquare, desc: "Embedding retrieval + Qwen2.5-7B LLM for contract Q&A", source: "all-MiniLM-L6-v2 + Qwen/Qwen2.5-7B-Instruct" },
+ { name: "Clause Redlining", icon: PenTool, desc: "18+ legal templates + LLM refinement for safer clause alternatives", source: "FTC/EU/CFPB standards + Qwen2.5-7B" },
+ { name: "docTR OCR", icon: ScanLine, desc: "Smart PDF router: auto-detects scanned PDFs and extracts text", source: "docTR fast_base + crnn_vgg16_bn" },
  ].map((m) => (
  <div key={m.name} className="border border-zinc-100 rounded-xl p-4 hover:border-zinc-200 hover:shadow-sm transition-all">
  <div className="flex items-center gap-2 mb-2">
@@ -211,7 +216,7 @@
  <div className="max-w-6xl mx-auto px-4 sm:px-6 py-8 flex flex-col sm:flex-row justify-between items-center gap-4">
  <div className="flex items-center gap-2">
  <ShieldCheck className="w-4 h-4 text-zinc-300" />
- <span className="text-[13px] text-zinc-400">ClauseGuard v3.0 — not legal advice</span>
+ <span className="text-[13px] text-zinc-400">ClauseGuard v4.0 — not legal advice</span>
  </div>
  <div className="flex gap-5 text-[13px] text-zinc-400">
  <Link href="/privacy" className="hover:text-zinc-600">Privacy</Link>
web/components/nav.tsx CHANGED
@@ -2,7 +2,7 @@
 
  import Link from "next/link";
  import { usePathname } from "next/navigation";
- import { ShieldCheck, Menu, X, Crown, GitCompare } from "lucide-react";
+ import { ShieldCheck, Menu, X, Crown, GitCompare, MessageSquare } from "lucide-react";
  import { useState, useEffect } from "react";
  import { createClient } from "@/lib/supabase/client";
 
@@ -27,7 +27,6 @@ export function Nav() {
  const user = data.user;
  setUserEmail(user?.email || null);
  if (user) {
- // Fetch role from database — no hardcoded emails
  const { data: profile } = await supabase
  .from("profiles")
  .select("role")
@@ -44,7 +43,7 @@
  <Link href="/" className="flex items-center gap-2">
  <ShieldCheck className="w-5 h-5 text-zinc-900" strokeWidth={2.2} />
  <span className="font-semibold text-[15px] tracking-tight text-zinc-900">ClauseGuard</span>
- <span className="hidden sm:inline text-[10px] font-medium text-zinc-400 ml-1 border border-zinc-200 px-1.5 py-0.5 rounded">v3.0</span>
+ <span className="hidden sm:inline text-[10px] font-medium text-zinc-400 ml-1 border border-zinc-200 px-1.5 py-0.5 rounded">v4.0</span>
  </Link>
 
  <div className="hidden md:flex items-center gap-1">