gaurv007 committed on
Commit
5bc867e
·
verified ·
1 Parent(s): a393ff3

fix(api): v4.1 — sliding window rate limiter, RAG TTL, input validation, proper IP extraction

Browse files
Files changed (1) hide show
  1. api/main.py +1 -435
api/main.py CHANGED
@@ -1,435 +1 @@
1
- """
2
- ClauseGuard — FastAPI Backend v4.0
3
- ══════════════════════════════════
4
- New in v4.0:
5
- • /api/redline — clause redlining suggestions
6
- • /api/chat — RAG chatbot (streaming)
7
- • /api/ocr — OCR scanned PDF extraction
8
- • Updated analysis to include redlining data
9
- """
10
-
11
- import os
12
- import re
13
- import json
14
- import time
15
- import uuid
16
- import tempfile
17
- from contextlib import asynccontextmanager
18
- from typing import Optional
19
- from collections import defaultdict
20
- from datetime import datetime
21
-
22
- import httpx
23
- import numpy as np
24
- from fastapi import FastAPI, HTTPException, Depends, Body, Request, UploadFile, File as FastAPIFile
25
- from fastapi.middleware.cors import CORSMiddleware
26
- from fastapi.responses import StreamingResponse
27
- from pydantic import BaseModel, Field
28
-
29
- from auth import get_current_user, require_auth
30
-
31
- # ── Import shared modules ──
32
- import sys
33
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
34
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
35
-
36
- try:
37
- from app import (
38
- split_clauses, classify_cuad, extract_entities,
39
- detect_contradictions, compute_risk_score, analyze_contract,
40
- CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
41
- cuad_model, cuad_tokenizer
42
- )
43
- from obligations import extract_obligations
44
- from compliance import check_compliance
45
- from compare import compare_contracts
46
- from redlining import generate_redlines
47
- from chatbot import index_contract, chat_respond
48
- from ocr_engine import parse_pdf_smart, get_ocr_status
49
- _SHARED_MODULES = True
50
- except ImportError as e:
51
- _SHARED_MODULES = False
52
- print(f"[API] WARNING: Could not import shared modules: {e}")
53
-
54
- # ─── Config ───
55
# Deployment settings are read once from the environment at import time.
# String settings default to "" when the variable is unset.
SUPABASE_URL = os.getenv("SUPABASE_URL", "")
SUPABASE_SERVICE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "")
HF_API_TOKEN = os.getenv("HF_API_TOKEN", "")
SAULLM_ENDPOINT = os.getenv("SAULLM_ENDPOINT", "")
# Upper bound, in characters, on contract text accepted for analysis.
MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "100000"))
60
-
61
- # ─── Rate Limiting ───
62
# Per-client request timestamps (seconds since the epoch, oldest first).
_rate_limits: dict[str, list[float]] = {}
RATE_LIMIT_REQUESTS = 30
RATE_LIMIT_WINDOW = 60
# Once more than this many clients are tracked, idle ones are purged so the
# table cannot grow without bound.
_RATE_LIMIT_MAX_TRACKED = 10_000


def _check_rate_limit(client_ip: str) -> bool:
    """Return True if ``client_ip`` may make another request right now.

    Sliding-window limiter: a client is allowed at most
    ``RATE_LIMIT_REQUESTS`` requests in any ``RATE_LIMIT_WINDOW``-second
    span. Unlike a fixed window, this does not permit a double burst across
    a window boundary. Allowed requests are recorded; rejected ones are not.

    Args:
        client_ip: Key identifying the caller (any string).

    Returns:
        True when the request is allowed, False when the limit is reached.
    """
    now = time.time()
    cutoff = now - RATE_LIMIT_WINDOW

    # Drop clients whose newest hit has already left the window.
    if len(_rate_limits) > _RATE_LIMIT_MAX_TRACKED:
        stale = [
            ip for ip, hits in _rate_limits.items()
            if not hits or hits[-1] < cutoff
        ]
        for ip in stale:
            del _rate_limits[ip]

    recent = [t for t in _rate_limits.get(client_ip, ()) if t >= cutoff]
    allowed = len(recent) < RATE_LIMIT_REQUESTS
    if allowed:
        recent.append(now)
    _rate_limits[client_ip] = recent
    return allowed
79
-
80
- # ─── Supabase helper ───
81
async def supabase_insert(table: str, data: dict):
    """Best-effort insert of one row into a Supabase table over REST.

    Does nothing when Supabase credentials are not configured. Failures
    never propagate to the caller, but they are logged so that lost writes
    are visible instead of silently discarded.

    Args:
        table: Table name appended to ``/rest/v1/``.
        data: JSON-serializable row payload.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{SUPABASE_URL}/rest/v1/{table}",
                json=data,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    "Content-Type": "application/json",
                    "Prefer": "return=minimal",
                },
                timeout=10.0,
            )
        if resp.status_code >= 400:
            print(f"[API] Supabase insert into '{table}' failed: HTTP {resp.status_code}")
    except Exception as e:
        # Persistence is optional; report the problem and carry on.
        print(f"[API] Supabase insert into '{table}' error: {e}")
99
-
100
async def supabase_query(table: str, params: dict, headers_extra: Optional[dict] = None):
    """Query a Supabase table over REST and return the decoded rows.

    Args:
        table: Table name appended to ``/rest/v1/``.
        params: Query-string parameters (filters, select, order, ...).
        headers_extra: Optional extra request headers merged over the auth
            headers. Defaults to none; a fresh mapping is used per call
            rather than a shared mutable default.

    Returns:
        The parsed JSON body on HTTP 200, otherwise an empty list. An empty
        list is also returned when credentials are missing or the request
        fails; failures are logged.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return []
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{SUPABASE_URL}/rest/v1/{table}",
                params=params,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    **(headers_extra or {}),
                },
                timeout=10.0,
            )
        if resp.status_code != 200:
            print(f"[API] Supabase query on '{table}' failed: HTTP {resp.status_code}")
            return []
        return resp.json()
    except Exception as e:
        print(f"[API] Supabase query on '{table}' error: {e}")
        return []
118
-
119
- # ─── In-memory RAG session store ───
120
# Maximum number of sessions kept in memory at once.
_RAG_SESSION_MAX = 100
# session_id -> indexed contract data, held only for the process lifetime.
_rag_sessions: dict = dict()
122
-
123
- # ─── Request/Response Models ───
124
class AnalyzeRequest(BaseModel):
    """Body of ``POST /api/analyze``.

    Either ``text`` or ``clauses`` supplies the contract; ``clauses`` is
    joined into one text when ``text`` is absent.
    """

    # Full contract text; rejected by validation when shorter than 50 chars.
    text: Optional[str] = Field(None, min_length=50)
    # Pre-split clauses. Typed as strings so malformed items are rejected
    # during validation instead of failing later when they are joined.
    clauses: Optional[list[str]] = None
    # Where the contract came from, stored with the saved analysis.
    source_url: Optional[str] = None
128
-
129
class CompareRequest(BaseModel):
    """Body of ``POST /api/compare``: the two contract texts to compare.

    Both texts must be at least 50 characters and at most
    ``MAX_TEXT_LENGTH`` characters, matching the size limit enforced by
    ``/api/analyze``.
    """

    text_a: str = Field(..., min_length=50, max_length=MAX_TEXT_LENGTH)
    text_b: str = Field(..., min_length=50, max_length=MAX_TEXT_LENGTH)
132
-
133
class ExplainRequest(BaseModel):
    """Body of ``POST /api/explain``: one clause and its category."""

    # Clause text to explain (10 to 2000 characters).
    clause: str = Field(
        ...,
        min_length=10,
        max_length=2000,
    )
    # Category name used to look up the default description.
    category: str
136
-
137
class ExplainResponse(BaseModel):
    """Response of ``POST /api/explain``."""

    # Echo of the request fields.
    clause: str
    category: str
    # Explanation, legal basis and recommendation for the clause.
    explanation: str
    legal_basis: str
    recommendation: str
143
-
144
class ChatRequest(BaseModel):
    """Body of ``POST /api/chat`` and ``POST /api/chat/stream``."""

    # User question (1 to 2000 characters).
    message: str = Field(
        ...,
        min_length=1,
        max_length=2000,
    )
    # Identifier of a session returned by /api/analyze.
    session_id: str
    # Earlier conversation turns, passed through to the chatbot unchanged.
    history: Optional[list[dict]] = None
148
-
149
class RedlineRequest(BaseModel):
    """Body of ``POST /api/redline``.

    Supply ``session_id`` to reuse a stored analysis, or ``text`` to analyze
    a contract on the fly.
    """

    # Session returned by an earlier /api/analyze call.
    session_id: Optional[str] = None
    # Raw contract text; capped at MAX_TEXT_LENGTH like /api/analyze input.
    text: Optional[str] = Field(None, max_length=MAX_TEXT_LENGTH)
    # Forwarded to generate_redlines as its use_llm flag.
    use_llm: bool = True
153
-
154
- # ─── App ───
155
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; performs no startup or shutdown work."""
    # Control returns to the server here for the lifetime of the app.
    yield
158
-
159
app = FastAPI(
    title="ClauseGuard API",
    version="4.0.0",
    lifespan=lifespan,
)

# Browser origins allowed to call the API: the deployed web app plus the
# local development servers.
ALLOWED_ORIGINS = [
    "https://clauseguardweb.netlify.app",
    *(f"http://localhost:{port}" for port in (3000, 3001)),
]

# CORS: the explicit origins above, plus any chrome-extension:// origin via
# the regex.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_origin_regex=r"^chrome-extension://.*$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
174
-
175
@app.get("/health")
async def health():
    """Report service status, active classifier, OCR status and features."""
    if _SHARED_MODULES:
        # "ml" only when the classifier model object is truthy.
        classifier = "ml" if cuad_model else "regex"
        ocr = get_ocr_status()
    else:
        classifier = "regex"
        ocr = "unavailable"

    payload = {
        "status": "ok",
        "model": classifier,
        "version": "4.0.0",
        "shared_modules": _SHARED_MODULES,
        "ocr": ocr,
        "features": ["analyze", "compare", "redline", "chat", "ocr"],
    }
    return payload
187
-
188
@app.post("/api/analyze")
async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
    """Analyze contract text and return risk, clause and redline findings.

    Steps: rate-limit by client host, validate text length, split the text
    into clauses, classify each clause, then derive entities,
    contradictions, a risk score, obligations, compliance results and
    redline suggestions. When indexing succeeds, the contract is stored in
    the in-memory RAG session store and its ``session_id`` is returned.
    For an authenticated user the result is also written to the Supabase
    ``analyses`` table.

    Raises:
        HTTPException: 429 when rate limited; 503 when the shared analysis
            modules failed to import; 400 when the text is too short, too
            long, or yields no clauses.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    # The analysis helpers exist only when the shared-module import
    # succeeded; without this check a failed import surfaces as a NameError.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis modules are unavailable.")

    text = req.text
    if not text and req.clauses:
        # Coerce items to str so a non-string clause cannot break the join.
        if isinstance(req.clauses, list):
            text = "\n\n".join(str(item) for item in req.clauses)
        else:
            text = str(req.clauses)

    if not text or len(text.strip()) < 50:
        raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} chars)")

    start = time.time()

    clauses = split_clauses(text)
    if not clauses:
        raise HTTPException(status_code=400, detail="No clauses detected")

    # One result row per (clause, prediction) pair.
    clause_results = []
    for clause in clauses:
        for pred in classify_cuad(clause) or []:
            clause_results.append({
                "text": clause,
                "label": pred["label"],
                "confidence": pred["confidence"],
                "risk": pred["risk"],
                "description": pred["description"],
                "source": pred.get("source", "unknown"),
            })

    entities = extract_entities(text)
    contradictions = detect_contradictions(clause_results, text)
    risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
    obligations = extract_obligations(text)
    compliance = check_compliance(text)

    # v4.0: Redlining (best-effort; a failure leaves the list empty).
    redlines = []
    try:
        redlines = generate_redlines({"clauses": clause_results}, use_llm=True)
    except Exception as e:
        print(f"[API] Redlining error: {e}")

    latency = int((time.time() - start) * 1000)

    results_for_db = [
        {
            "text": cr["text"],
            "categories": [{
                "name": cr["label"],
                "severity": cr["risk"],
                "confidence": cr["confidence"],
                "description": cr["description"],
            }],
        }
        for cr in clause_results
    ]
    # A clause with several predictions counts once.
    flagged_count = len({cr["text"] for cr in clause_results})

    # v4.0: RAG indexing (best-effort; session_id stays None on failure).
    session_id = None
    try:
        chunks, embeddings, _status = index_contract(text)
        if chunks and embeddings is not None:
            session_id = uuid.uuid4().hex[:12]
            if len(_rag_sessions) >= _RAG_SESSION_MAX:
                # Dicts preserve insertion order, so the first key is the oldest.
                oldest = next(iter(_rag_sessions))
                del _rag_sessions[oldest]
            _rag_sessions[session_id] = {
                "chunks": chunks,
                "embeddings": embeddings,
                "analysis": {
                    "risk": {"score": risk, "grade": grade, "breakdown": sev_counts},
                    "metadata": {"total_clauses": len(clauses), "flagged_clauses": len(clause_results)},
                    "clauses": clause_results[:30],
                    "entities": entities[:30],
                    "contradictions": contradictions,
                },
            }
    except Exception as e:
        print(f"[API] RAG indexing error: {e}")

    if user:
        await supabase_insert("analyses", {
            "user_id": user["id"],
            "source_url": req.source_url,
            "total_clauses": len(clauses),
            "flagged_count": flagged_count,
            "risk_score": risk,
            "grade": grade,
            "clauses": results_for_db,
            "entities": entities,
            "contradictions": contradictions,
            "obligations": obligations,
            "compliance": compliance,
        })

    return {
        "risk_score": risk,
        "grade": grade,
        "total_clauses": len(clauses),
        "flagged_count": flagged_count,
        "results": results_for_db,
        "entities": entities,
        "contradictions": contradictions,
        "obligations": obligations,
        "compliance": compliance,
        "redlines": redlines,
        "model": "ml" if cuad_model else "regex",
        "latency_ms": latency,
        "session_id": session_id,
    }
304
-
305
@app.post("/api/compare")
async def compare(req: CompareRequest, request: Request):
    """Compare two contract texts and return the comparison result.

    Raises:
        HTTPException: 429 when rate limited; 503 when the shared modules
            (which provide ``compare_contracts``) failed to import.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    # Without this check a failed shared-module import surfaces as a NameError.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis modules are unavailable.")
    return compare_contracts(req.text_a, req.text_b)
311
-
312
@app.post("/api/redline")
async def redline(req: RedlineRequest, request: Request):
    """Generate redline suggestions for a stored session or for raw text.

    A known ``session_id`` takes precedence; otherwise ``text`` is analyzed
    on the fly. An unknown ``session_id`` falls back to ``text`` when given.

    Raises:
        HTTPException: 429 when rate limited; 503 when the shared modules
            failed to import; 400 when analysis reports an error or neither
            a usable ``session_id`` nor ``text`` is supplied.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    # Without this check a failed shared-module import surfaces as a NameError.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis modules are unavailable.")

    session = _rag_sessions.get(req.session_id) if req.session_id else None
    if session is not None:
        analysis = session["analysis"]
    elif req.text:
        result, error = analyze_contract(req.text)
        if error:
            raise HTTPException(status_code=400, detail=error)
        analysis = result
    else:
        raise HTTPException(status_code=400, detail="Provide session_id or text")

    redlines = generate_redlines(analysis, use_llm=req.use_llm)
    return {"redlines": redlines, "count": len(redlines)}
330
-
331
@app.post("/api/chat")
async def chat(req: ChatRequest, request: Request):
    """Answer a question about a previously analyzed contract.

    Consumes every partial response from ``chat_respond`` and returns the
    last one.

    Raises:
        HTTPException: 429 when rate limited; 404 for an unknown session.
    """
    caller = request.client.host if request.client else "unknown"
    allowed = _check_rate_limit(caller)
    if not allowed:
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    stored = _rag_sessions.get(req.session_id)
    if stored is None:
        raise HTTPException(status_code=404, detail="Session not found. Analyze a contract first.")

    partials = chat_respond(
        req.message,
        req.history or [],
        stored["chunks"],
        stored["embeddings"],
        stored["analysis"],
    )
    # The loop variable ends up holding the last partial; it stays "" when
    # nothing is yielded.
    answer = ""
    for answer in partials:
        pass

    return {"response": answer, "session_id": req.session_id}
347
-
348
@app.post("/api/chat/stream")
async def chat_stream(req: ChatRequest, request: Request):
    """Stream a chat answer as server-sent events.

    Each event carries ``{"delta": ...}`` with the text added since the
    previous partial response; the stream ends with ``data: [DONE]``.

    Raises:
        HTTPException: 429 when rate limited; 404 for an unknown session.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session not found.")

    session = _rag_sessions[req.session_id]

    # Deliberately a plain (sync) generator: chat_respond is a blocking
    # generator, and StreamingResponse iterates sync iterables in a
    # threadpool, so the event loop is not stalled between chunks.
    def generate():
        last = ""
        for partial in chat_respond(
            req.message, req.history or [],
            session["chunks"], session["embeddings"], session["analysis"]
        ):
            # Partials are cumulative; emit only the new suffix.
            delta = partial[len(last):]
            last = partial
            if delta:
                yield f"data: {json.dumps({'delta': delta})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
372
-
373
@app.post("/api/ocr")
async def ocr_endpoint(file: UploadFile = FastAPIFile(...)):
    """Extract text from an uploaded PDF.

    The upload is written to a temporary file, parsed with
    ``parse_pdf_smart``, and the temporary file is always removed afterward.

    Raises:
        HTTPException: 503 when the shared modules failed to import; 400
            when the file is not a ``.pdf`` or the parser reports an error.
    """
    # Without this check a failed shared-module import surfaces as a NameError.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="OCR modules are unavailable.")
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files supported")

    # Read before creating the temp file so a failed read leaves nothing behind.
    content = await file.read()

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            # Record the path first so cleanup still runs if the write fails.
            tmp_path = tmp.name
            tmp.write(content)

        text, error, method = parse_pdf_smart(tmp_path)
        if error:
            raise HTTPException(status_code=400, detail=error)
        return {"text": text, "method": method, "chars": len(text) if text else 0, "filename": file.filename}
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Cleanup must not mask the real outcome of the request.
                pass
390
-
391
@app.post("/api/explain", response_model=ExplainResponse)
async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
    """Explain why a clause may be risky (authenticated users only).

    Starts from static defaults: the category description from ``DESC_MAP``
    plus generic legal-basis and recommendation text. When an LLM endpoint
    and token are configured, the endpoint is queried and its output, split
    on blank lines, replaces the defaults piece by piece. Any LLM failure is
    logged and the defaults are returned.
    """
    # DESC_MAP exists only when the shared-module import succeeded.
    if _SHARED_MODULES:
        desc = DESC_MAP.get(req.category, "Unknown category.")
    else:
        desc = "Unknown category."
    legal = "Consult local consumer protection laws."
    recommendation = "Review this clause carefully."

    if SAULLM_ENDPOINT and HF_API_TOKEN:
        try:
            prompt = (
                f"Analyze this contract clause and explain why it may be risky.\n\n"
                f"Clause: \"{req.clause}\"\nCategory: {req.category}\n\n"
                f"Provide: 1) Plain-English explanation 2) Legal basis 3) Recommendation"
            )
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.post(
                    SAULLM_ENDPOINT,
                    json={"inputs": prompt, "parameters": {"max_new_tokens": 300, "temperature": 0.3}},
                    headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
                )
            if resp.status_code == 200:
                output = resp.json()
                # The endpoint may answer with a list of results or one object.
                if isinstance(output, list):
                    first = output[0] if output else {}
                else:
                    first = output
                generated = first.get("generated_text", "") if isinstance(first, dict) else ""
                if generated and len(generated) > 50:
                    parts = generated.split("\n\n")
                    desc = parts[0]
                    if len(parts) > 1:
                        legal = parts[1]
                    if len(parts) > 2:
                        recommendation = parts[2]
            else:
                print(f"[API] Explain LLM request failed: HTTP {resp.status_code}")
        except Exception as e:
            # The LLM is an optional enhancement; fall back to the defaults.
            print(f"[API] Explain LLM error: {e}")

    return ExplainResponse(clause=req.clause, category=req.category,
                           explanation=desc, legal_basis=legal, recommendation=recommendation)
423
-
424
@app.get("/api/history")
async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
    """Return the authenticated user's saved analyses, newest first.

    Args:
        limit: Page size, clamped to the range 1..100.
        offset: Number of rows to skip; negative values are treated as 0.

    Returns:
        ``{"analyses": [...], "limit": ..., "offset": ...}`` with the
        effective (clamped) paging values.
    """
    # Clamp both ends so out-of-range values never reach the query.
    limit = max(1, min(limit, 100))
    offset = max(0, offset)
    data = await supabase_query("analyses", {
        "user_id": f"eq.{user['id']}", "select": "*",
        "order": "created_at.desc", "limit": str(limit), "offset": str(offset),
    })
    return {"analyses": data, "limit": limit, "offset": offset}
432
-
433
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
 
1
+ file:/app/api_main.py