gaurv007 committed on
Commit
6e7c8ba
·
verified ·
1 Parent(s): 7f5fe5c

v4.0: Backend API — add /api/redline, /api/chat, /api/ocr endpoints

Browse files
Files changed (1) hide show
  1. api/main.py +181 -85
api/main.py CHANGED
@@ -1,19 +1,19 @@
1
  """
2
- ClauseGuard — FastAPI Backend v3.0
3
  ══════════════════════════════════
4
- FIXED in v3.0:
5
- Imports shared modules (no code duplication)
6
- Fixed API schema to accept both {text} and {clauses} from extension
7
- Added rate limiting
8
- Added max text length validation
9
- • Fixed CORS (removed wildcard)
10
- • Added proper error responses
11
  """
12
 
13
  import os
14
  import re
15
  import json
16
  import time
 
 
17
  from contextlib import asynccontextmanager
18
  from typing import Optional
19
  from collections import defaultdict
@@ -21,14 +21,14 @@ from datetime import datetime
21
 
22
  import httpx
23
  import numpy as np
24
- from fastapi import FastAPI, HTTPException, Depends, Body, Request
25
  from fastapi.middleware.cors import CORSMiddleware
 
26
  from pydantic import BaseModel, Field
27
 
28
  from auth import get_current_user, require_auth
29
 
30
  # ── Import shared modules ──
31
- # When deployed, these must be in the same directory or on PYTHONPATH
32
  import sys
33
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
34
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -36,29 +36,32 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
36
  try:
37
  from app import (
38
  split_clauses, classify_cuad, extract_entities,
39
- detect_contradictions, compute_risk_score,
40
  CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
41
  cuad_model, cuad_tokenizer
42
  )
43
  from obligations import extract_obligations
44
  from compliance import check_compliance
45
  from compare import compare_contracts
 
 
 
46
  _SHARED_MODULES = True
47
- except ImportError:
48
  _SHARED_MODULES = False
49
- print("[API] WARNING: Could not import shared modules, using inline fallbacks")
50
 
51
  # ─── Config ───
52
  SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
53
  SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
54
  HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
55
  SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
56
- MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "100000")) # 100KB default
57
 
58
  # ─── Rate Limiting ───
59
- _rate_limits = {} # ip -> (count, window_start)
60
  RATE_LIMIT_REQUESTS = 30
61
- RATE_LIMIT_WINDOW = 60 # seconds
62
 
63
  def _check_rate_limit(client_ip: str) -> bool:
64
  now = time.time()
@@ -113,25 +116,16 @@ async def supabase_query(table: str, params: dict, headers_extra: dict = {}):
113
  except Exception:
114
  return []
115
 
 
 
 
 
116
  # ─── Request/Response Models ───
117
  class AnalyzeRequest(BaseModel):
118
  text: Optional[str] = Field(None, min_length=50)
119
- clauses: Optional[list] = None # FIXED: accept clauses array from extension
120
  source_url: Optional[str] = None
121
 
122
- class AnalyzeResponse(BaseModel):
123
- risk_score: int
124
- grade: str
125
- total_clauses: int
126
- flagged_count: int
127
- results: list[dict]
128
- entities: list[dict]
129
- contradictions: list[dict]
130
- obligations: list[dict]
131
- compliance: dict
132
- model: str
133
- latency_ms: int
134
-
135
  class CompareRequest(BaseModel):
136
  text_a: str = Field(..., min_length=50)
137
  text_b: str = Field(..., min_length=50)
@@ -147,21 +141,28 @@ class ExplainResponse(BaseModel):
147
  legal_basis: str
148
  recommendation: str
149
 
 
 
 
 
 
 
 
 
 
 
150
  # ─── App ───
151
  @asynccontextmanager
152
  async def lifespan(app: FastAPI):
153
- # Models are loaded when app.py is imported
154
  yield
155
 
156
- app = FastAPI(title="ClauseGuard API", version="3.0.0", lifespan=lifespan)
157
 
158
- # FIXED: No wildcard CORS
159
  ALLOWED_ORIGINS = [
160
  "https://clauseguardweb.netlify.app",
161
  "http://localhost:3000",
162
  "http://localhost:3001",
163
  ]
164
- # Allow chrome extensions
165
  app.add_middleware(
166
  CORSMiddleware,
167
  allow_origins=ALLOWED_ORIGINS,
@@ -174,36 +175,36 @@ app.add_middleware(
174
  @app.get("/health")
175
  async def health():
176
  model_status = "ml" if _SHARED_MODULES and cuad_model else "regex"
 
177
  return {
178
  "status": "ok",
179
  "model": model_status,
180
- "version": "3.0.0",
181
  "shared_modules": _SHARED_MODULES,
 
 
182
  }
183
 
184
- @app.post("/api/analyze", response_model=AnalyzeResponse)
185
  async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
186
- # Rate limiting
187
  client_ip = request.client.host if request.client else "unknown"
188
  if not _check_rate_limit(client_ip):
189
- raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again in 60 seconds.")
190
 
191
- # FIXED: Accept either text or clauses from extension
192
  text = req.text
193
  if not text and req.clauses:
194
  text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)
195
 
196
  if not text or len(text.strip()) < 50:
197
  raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
198
-
199
- # Max length check
200
  if len(text) > MAX_TEXT_LENGTH:
201
- raise HTTPException(status_code=400, detail=f"Text too long (maximum {MAX_TEXT_LENGTH} characters)")
202
 
203
  start = time.time()
 
204
  clauses = split_clauses(text)
205
  if not clauses:
206
- raise HTTPException(status_code=400, detail="No clauses detected in document")
207
 
208
  clause_results = []
209
  for clause in clauses:
@@ -224,6 +225,15 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
224
  risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
225
  obligations = extract_obligations(text)
226
  compliance = check_compliance(text)
 
 
 
 
 
 
 
 
 
227
  latency = int((time.time() - start) * 1000)
228
 
229
  results_for_db = []
@@ -238,6 +248,29 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
238
  }],
239
  })
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  if user:
242
  await supabase_insert("analyses", {
243
  "user_id": user["id"],
@@ -253,46 +286,120 @@ async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] =
253
  "compliance": compliance,
254
  })
255
 
256
- return AnalyzeResponse(
257
- risk_score=risk,
258
- grade=grade,
259
- total_clauses=len(clauses),
260
- flagged_count=len(set(cr["text"] for cr in clause_results)),
261
- results=results_for_db,
262
- entities=entities,
263
- contradictions=contradictions,
264
- obligations=obligations,
265
- compliance=compliance,
266
- model="ml" if cuad_model else "regex",
267
- latency_ms=latency,
268
- )
 
 
269
 
270
  @app.post("/api/compare")
271
  async def compare(req: CompareRequest, request: Request):
272
  client_ip = request.client.host if request.client else "unknown"
273
  if not _check_rate_limit(client_ip):
274
  raise HTTPException(status_code=429, detail="Rate limit exceeded.")
275
- result = compare_contracts(req.text_a, req.text_b)
276
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  @app.post("/api/explain", response_model=ExplainResponse)
279
  async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
280
  desc = DESC_MAP.get(req.category, "Unknown category.")
281
  legal = "Consult local consumer protection laws."
282
- recommendation = "Review this clause carefully. Consider negotiating or seeking legal advice before agreeing."
283
 
284
  if SAULLM_ENDPOINT and HF_API_TOKEN:
285
  try:
286
  prompt = (
287
- f"You are a consumer protection legal analyst. Analyze this contract clause "
288
- f"and explain why it may be unfair or risky.\n\n"
289
- f"Clause: \"{req.clause}\"\n"
290
- f"Category: {req.category}\n\n"
291
- f"Provide:\n"
292
- f"1. A plain-English explanation of what this clause means\n"
293
- f"2. The specific legal basis or consumer protection concern\n"
294
- f"3. A practical recommendation\n\n"
295
- f"Be concise. 3-4 sentences per section."
296
  )
297
  async with httpx.AsyncClient(timeout=30.0) as client:
298
  resp = await client.post(
@@ -311,27 +418,16 @@ async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
311
  except Exception:
312
  pass
313
 
314
- return ExplainResponse(
315
- clause=req.clause,
316
- category=req.category,
317
- explanation=desc,
318
- legal_basis=legal,
319
- recommendation=recommendation,
320
- )
321
 
322
  @app.get("/api/history")
323
  async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
324
  limit = min(limit, 100)
325
- data = await supabase_query(
326
- "analyses",
327
- {
328
- "user_id": f"eq.{user['id']}",
329
- "select": "*",
330
- "order": "created_at.desc",
331
- "limit": str(limit),
332
- "offset": str(offset),
333
- },
334
- )
335
  return {"analyses": data, "limit": limit, "offset": offset}
336
 
337
  if __name__ == "__main__":
 
1
  """
2
+ ClauseGuard — FastAPI Backend v4.0
3
  ══════════════════════════════════
4
+ New in v4.0:
5
+ /api/redline clause redlining suggestions
6
+ /api/chat RAG chatbot (streaming)
7
+ /api/ocr OCR scanned PDF extraction
8
+ Updated analysis to include redlining data
 
 
9
  """
10
 
11
  import os
12
  import re
13
  import json
14
  import time
15
+ import uuid
16
+ import tempfile
17
  from contextlib import asynccontextmanager
18
  from typing import Optional
19
  from collections import defaultdict
 
21
 
22
  import httpx
23
  import numpy as np
24
+ from fastapi import FastAPI, HTTPException, Depends, Body, Request, UploadFile, File as FastAPIFile
25
  from fastapi.middleware.cors import CORSMiddleware
26
+ from fastapi.responses import StreamingResponse
27
  from pydantic import BaseModel, Field
28
 
29
  from auth import get_current_user, require_auth
30
 
31
  # ── Import shared modules ──
 
32
  import sys
33
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
34
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
36
  try:
37
  from app import (
38
  split_clauses, classify_cuad, extract_entities,
39
+ detect_contradictions, compute_risk_score, analyze_contract,
40
  CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
41
  cuad_model, cuad_tokenizer
42
  )
43
  from obligations import extract_obligations
44
  from compliance import check_compliance
45
  from compare import compare_contracts
46
+ from redlining import generate_redlines
47
+ from chatbot import index_contract, chat_respond
48
+ from ocr_engine import parse_pdf_smart, get_ocr_status
49
  _SHARED_MODULES = True
50
+ except ImportError as e:
51
  _SHARED_MODULES = False
52
+ print(f"[API] WARNING: Could not import shared modules: {e}")
53
 
54
  # ─── Config ───
55
  SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
56
  SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
57
  HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
58
  SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
59
+ MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "100000"))
60
 
61
  # ─── Rate Limiting ───
62
+ _rate_limits = {}
63
  RATE_LIMIT_REQUESTS = 30
64
+ RATE_LIMIT_WINDOW = 60
65
 
66
  def _check_rate_limit(client_ip: str) -> bool:
67
  now = time.time()
 
116
  except Exception:
117
  return []
118
 
119
+ # ─── In-memory RAG session store ───
120
+ _rag_sessions: dict = {}
121
+ _RAG_SESSION_MAX = 100
122
+
123
  # ─── Request/Response Models ───
124
  class AnalyzeRequest(BaseModel):
125
  text: Optional[str] = Field(None, min_length=50)
126
+ clauses: Optional[list] = None
127
  source_url: Optional[str] = None
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  class CompareRequest(BaseModel):
130
  text_a: str = Field(..., min_length=50)
131
  text_b: str = Field(..., min_length=50)
 
141
  legal_basis: str
142
  recommendation: str
143
 
144
class ChatRequest(BaseModel):
    """Request body shared by the /api/chat and /api/chat/stream endpoints."""

    # The user's question; must be 1-2000 characters.
    message: str = Field(default=..., min_length=1, max_length=2000)

    # Key into the in-memory session store; /api/analyze returns it as "session_id".
    session_id: str

    # Earlier conversation turns, forwarded unchanged to chat_respond.
    # NOTE(review): the per-entry dict schema is not visible in this file — confirm against chatbot.py.
    history: Optional[list[dict]] = None
148
+
149
class RedlineRequest(BaseModel):
    """Request body for /api/redline.

    The endpoint uses the stored analysis when ``session_id`` names a known
    session, otherwise it analyzes ``text``; if neither is usable it answers 400.
    """

    # Session produced by an earlier /api/analyze call (optional).
    session_id: Optional[str] = None

    # Raw contract text, used only when no known session is given (optional).
    text: Optional[str] = None

    # Passed straight through to generate_redlines(use_llm=...).
    use_llm: bool = True
153
+
154
  # ─── App ───
155
  @asynccontextmanager
156
  async def lifespan(app: FastAPI):
 
157
  yield
158
 
159
+ app = FastAPI(title="ClauseGuard API", version="4.0.0", lifespan=lifespan)
160
 
 
161
  ALLOWED_ORIGINS = [
162
  "https://clauseguardweb.netlify.app",
163
  "http://localhost:3000",
164
  "http://localhost:3001",
165
  ]
 
166
  app.add_middleware(
167
  CORSMiddleware,
168
  allow_origins=ALLOWED_ORIGINS,
 
@app.get("/health")
async def health():
    """Report liveness, the active classifier backend, and OCR availability.

    When the shared modules failed to import, ``cuad_model`` and
    ``get_ocr_status`` do not exist, so both are only touched behind the
    ``_SHARED_MODULES`` check.
    """
    if _SHARED_MODULES:
        backend = "ml" if cuad_model else "regex"
        ocr_state = get_ocr_status()
    else:
        backend = "regex"
        ocr_state = "unavailable"

    return dict(
        status="ok",
        model=backend,
        version="4.0.0",
        shared_modules=_SHARED_MODULES,
        ocr=ocr_state,
        features=["analyze", "compare", "redline", "chat", "ocr"],
    )
187
 
188
+ @app.post("/api/analyze")
189
  async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
 
190
  client_ip = request.client.host if request.client else "unknown"
191
  if not _check_rate_limit(client_ip):
192
+ raise HTTPException(status_code=429, detail="Rate limit exceeded.")
193
 
 
194
  text = req.text
195
  if not text and req.clauses:
196
  text = "\n\n".join(req.clauses) if isinstance(req.clauses, list) else str(req.clauses)
197
 
198
  if not text or len(text.strip()) < 50:
199
  raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
 
 
200
  if len(text) > MAX_TEXT_LENGTH:
201
+ raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} chars)")
202
 
203
  start = time.time()
204
+
205
  clauses = split_clauses(text)
206
  if not clauses:
207
+ raise HTTPException(status_code=400, detail="No clauses detected")
208
 
209
  clause_results = []
210
  for clause in clauses:
 
225
  risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
226
  obligations = extract_obligations(text)
227
  compliance = check_compliance(text)
228
+
229
+ # v4.0: Redlining
230
+ analysis_for_redline = {"clauses": clause_results}
231
+ redlines = []
232
+ try:
233
+ redlines = generate_redlines(analysis_for_redline, use_llm=True)
234
+ except Exception as e:
235
+ print(f"[API] Redlining error: {e}")
236
+
237
  latency = int((time.time() - start) * 1000)
238
 
239
  results_for_db = []
 
248
  }],
249
  })
250
 
251
+ # v4.0: RAG indexing
252
+ session_id = None
253
+ try:
254
+ chunks, embeddings, _status = index_contract(text)
255
+ if chunks and embeddings is not None:
256
+ session_id = uuid.uuid4().hex[:12]
257
+ if len(_rag_sessions) >= _RAG_SESSION_MAX:
258
+ oldest = next(iter(_rag_sessions))
259
+ del _rag_sessions[oldest]
260
+ _rag_sessions[session_id] = {
261
+ "chunks": chunks,
262
+ "embeddings": embeddings,
263
+ "analysis": {
264
+ "risk": {"score": risk, "grade": grade, "breakdown": sev_counts},
265
+ "metadata": {"total_clauses": len(clauses), "flagged_clauses": len(clause_results)},
266
+ "clauses": clause_results[:30],
267
+ "entities": entities[:30],
268
+ "contradictions": contradictions,
269
+ },
270
+ }
271
+ except Exception as e:
272
+ print(f"[API] RAG indexing error: {e}")
273
+
274
  if user:
275
  await supabase_insert("analyses", {
276
  "user_id": user["id"],
 
286
  "compliance": compliance,
287
  })
288
 
289
+ return {
290
+ "risk_score": risk,
291
+ "grade": grade,
292
+ "total_clauses": len(clauses),
293
+ "flagged_count": len(set(cr["text"] for cr in clause_results)),
294
+ "results": results_for_db,
295
+ "entities": entities,
296
+ "contradictions": contradictions,
297
+ "obligations": obligations,
298
+ "compliance": compliance,
299
+ "redlines": redlines,
300
+ "model": "ml" if cuad_model else "regex",
301
+ "latency_ms": latency,
302
+ "session_id": session_id,
303
+ }
304
 
@app.post("/api/compare")
async def compare(req: CompareRequest, request: Request):
    """Compare two contract texts and return ``compare_contracts``' result as-is.

    Raises:
        HTTPException: 429 when the caller's address has exhausted its rate limit.
    """
    # Fall back to a shared "unknown" bucket when the client address is unavailable.
    caller = request.client.host if request.client else "unknown"

    if _check_rate_limit(caller):
        return compare_contracts(req.text_a, req.text_b)

    raise HTTPException(status_code=429, detail="Rate limit exceeded.")
311
+
312
@app.post("/api/redline")
async def redline(req: RedlineRequest, request: Request):
    """Generate redline suggestions from a stored session or from raw text.

    The analysis is taken from the in-memory session store when
    ``req.session_id`` names a known session; otherwise ``req.text`` is
    analyzed on the spot with ``analyze_contract``.

    Returns:
        ``{"redlines": [...], "count": <number of redlines>}``.

    Raises:
        HTTPException: 429 when rate limited; 400 when neither a known session
            nor text is supplied, when the text exceeds ``MAX_TEXT_LENGTH``,
            or when ``analyze_contract`` reports an error.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    # One lookup instead of a membership test followed by indexing.
    session = _rag_sessions.get(req.session_id) if req.session_id else None

    if session is not None:
        analysis = session["analysis"]
    elif req.text:
        # Same upper bound /api/analyze enforces; without it this path would
        # run a full analysis on arbitrarily large request bodies.
        if len(req.text) > MAX_TEXT_LENGTH:
            raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} chars)")
        result, error = analyze_contract(req.text)
        if error:
            raise HTTPException(status_code=400, detail=error)
        analysis = result
    else:
        raise HTTPException(status_code=400, detail="Provide session_id or text")

    redlines = generate_redlines(analysis, use_llm=req.use_llm)
    return {"redlines": redlines, "count": len(redlines)}
330
+
331
@app.post("/api/chat")
async def chat(req: ChatRequest, request: Request):
    """Answer a question about a previously analyzed contract (non-streaming).

    Returns:
        ``{"response": <final text>, "session_id": <echoed id>}``. The response
        is the last value yielded by ``chat_respond``, or ``""`` if it yields
        nothing.

    Raises:
        HTTPException: 429 when rate limited; 404 when the session is unknown.
    """
    # Function-scope import keeps this block self-contained.
    from fastapi.concurrency import run_in_threadpool

    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session not found. Analyze a contract first.")

    session = _rag_sessions[req.session_id]

    def _final_answer() -> str:
        # chat_respond is a synchronous generator; only its last value is wanted.
        answer = ""
        for answer in chat_respond(req.message, req.history or [],
                                   session["chunks"], session["embeddings"], session["analysis"]):
            pass
        return answer

    # Draining the generator directly inside this coroutine would hold the
    # event loop for the whole generation and stall every other request.
    # NOTE(review): assumes chat_respond is safe to call from a worker thread — confirm.
    response_text = await run_in_threadpool(_final_answer)

    return {"response": response_text, "session_id": req.session_id}
347
+
348
@app.post("/api/chat/stream")
async def chat_stream(req: ChatRequest, request: Request):
    """Stream a chat answer as server-sent events.

    Each event is ``data: {"delta": "<new text>"}``; the stream ends with
    ``data: [DONE]``.

    Raises:
        HTTPException: 429 when rate limited; 404 when the session is unknown.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session not found.")

    # Hold a reference so the stream keeps working even if the session is
    # evicted from the store while the response is in flight.
    session = _rag_sessions[req.session_id]

    def generate():
        # Deliberately a plain generator, not `async def`: StreamingResponse
        # iterates sync iterators in a worker thread, so the blocking
        # chat_respond loop no longer runs on the event loop.
        # NOTE(review): assumes chat_respond is safe to call from a worker thread — confirm.
        last = ""
        for partial in chat_respond(
            req.message, req.history or [],
            session["chunks"], session["embeddings"], session["analysis"]
        ):
            # Each partial is sliced against the previous one, i.e. treated as
            # the cumulative text so far; only the new suffix is sent.
            delta = partial[len(last):]
            last = partial
            if delta:
                yield f"data: {json.dumps({'delta': delta})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
372
+
373
@app.post("/api/ocr")
async def ocr_endpoint(file: UploadFile = FastAPIFile(...)):
    """Extract text from an uploaded PDF via ``parse_pdf_smart``.

    Returns:
        ``{"text", "method", "chars", "filename"}`` where ``chars`` is the
        length of the extracted text (0 when no text came back).

    Raises:
        HTTPException: 400 when the upload is not named ``*.pdf`` or when
            ``parse_pdf_smart`` reports an error.
    """
    # NOTE(review): unlike the other POST endpoints this one has no rate limit
    # and no upload size cap — consider adding both.
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files supported")

    # Read the upload BEFORE creating the temp file, so a failed read cannot
    # leave an orphaned file on disk.
    content = await file.read()

    tmp_path = None
    try:
        # delete=False: the file must outlive the `with` so the parser can
        # reopen it by path; cleanup happens in `finally` below.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(content)

        text, error, method = parse_pdf_smart(tmp_path)
        if error:
            raise HTTPException(status_code=400, detail=error)
        return {"text": text, "method": method, "chars": len(text) if text else 0, "filename": file.filename}
    finally:
        # Covers every exit path, including a failed write above.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
390
 
391
  @app.post("/api/explain", response_model=ExplainResponse)
392
  async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
393
  desc = DESC_MAP.get(req.category, "Unknown category.")
394
  legal = "Consult local consumer protection laws."
395
+ recommendation = "Review this clause carefully."
396
 
397
  if SAULLM_ENDPOINT and HF_API_TOKEN:
398
  try:
399
  prompt = (
400
+ f"Analyze this contract clause and explain why it may be risky.\n\n"
401
+ f"Clause: \"{req.clause}\"\nCategory: {req.category}\n\n"
402
+ f"Provide: 1) Plain-English explanation 2) Legal basis 3) Recommendation"
 
 
 
 
 
 
403
  )
404
  async with httpx.AsyncClient(timeout=30.0) as client:
405
  resp = await client.post(
 
418
  except Exception:
419
  pass
420
 
421
+ return ExplainResponse(clause=req.clause, category=req.category,
422
+ explanation=desc, legal_basis=legal, recommendation=recommendation)
 
 
 
 
 
423
 
@app.get("/api/history")
async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
    """Return the authenticated user's stored analyses, newest first.

    Args:
        user: Authenticated user record; only ``user["id"]`` is read.
        limit: Page size, clamped to the range 0-100.
        offset: Number of rows to skip; negative values are treated as 0.

    Returns:
        ``{"analyses": [...], "limit": <effective limit>, "offset": <effective offset>}``.
    """
    # Clamp both ends: previously only the upper bound was enforced, so a
    # negative limit or offset was forwarded to the query unchanged.
    limit = max(0, min(limit, 100))
    offset = max(0, offset)

    data = await supabase_query("analyses", {
        "user_id": f"eq.{user['id']}", "select": "*",
        "order": "created_at.desc", "limit": str(limit), "offset": str(offset),
    })
    return {"analyses": data, "limit": limit, "offset": offset}
432
 
433
  if __name__ == "__main__":