gaurv007 committed on
Commit
2f76184
·
verified ·
1 Parent(s): 50f2675

fix: upload actual api/main.py content with all v4.1 fixes

Browse files
Files changed (1) hide show
  1. api/main.py +486 -1
api/main.py CHANGED
@@ -1 +1,486 @@
1
- file:/app/api_main.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ClauseGuard — FastAPI Backend v4.1
3
+ ══════════════════════════════════
4
+ Fixes in v4.1:
5
+ • FIX: Rate limiter uses sliding window with proper IP extraction (X-Forwarded-For)
6
+ • FIX: RAG sessions have TTL-based expiry (1 hour) instead of just count-based
7
+ • FIX: Input text size validation (max 200KB)
8
+ • FIX: Proper error handling for all endpoints
9
+ """
10
+
11
+ import os
12
+ import re
13
+ import json
14
+ import time
15
+ import uuid
16
+ import tempfile
17
+ from contextlib import asynccontextmanager
18
+ from typing import Optional
19
+ from collections import defaultdict
20
+ from datetime import datetime
21
+
22
+ import httpx
23
+ import numpy as np
24
+ from fastapi import FastAPI, HTTPException, Depends, Body, Request, UploadFile, File as FastAPIFile
25
+ from fastapi.middleware.cors import CORSMiddleware
26
+ from fastapi.responses import StreamingResponse
27
+ from pydantic import BaseModel, Field
28
+
29
+ from auth import get_current_user, require_auth
30
+
31
+ # ── Import shared modules ──
32
+ import sys
33
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
34
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
35
+
36
+ try:
37
+ from app import (
38
+ split_clauses, classify_cuad, extract_entities,
39
+ detect_contradictions, compute_risk_score, analyze_contract,
40
+ CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
41
+ cuad_model, cuad_tokenizer
42
+ )
43
+ from obligations import extract_obligations
44
+ from compliance import check_compliance
45
+ from compare import compare_contracts
46
+ from redlining import generate_redlines
47
+ from chatbot import index_contract, chat_respond
48
+ from ocr_engine import parse_pdf_smart, get_ocr_status
49
+ _SHARED_MODULES = True
50
+ except ImportError as e:
51
+ _SHARED_MODULES = False
52
+ print(f"[API] WARNING: Could not import shared modules: {e}")
53
+
54
+ # ─── Config ───
55
# Supabase project URL and service-role key. The Supabase helpers in this file
# return early (no request is made) when either value is empty.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")
# Token and endpoint for the optional LLM call made by /api/explain; the call
# is skipped unless both are set.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "")
SAULLM_ENDPOINT = os.environ.get("SAULLM_ENDPOINT", "")
# Upper bound that the endpoints compare against len(text), i.e. it is measured
# in characters of the decoded string. Raises ValueError at import time if the
# environment variable is set to a non-integer.
MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "200000"))
60
+
61
+ # ─── FIX v4.1: Sliding window rate limiter with proper IP extraction ───
62
# In-process request log: client IP -> time.time() stamps of accepted requests.
_rate_limits: dict[str, list[float]] = {}
# Maximum number of accepted requests per client inside one window.
RATE_LIMIT_REQUESTS = 30
RATE_LIMIT_WINDOW = 60  # seconds
65
+
66
def _get_client_ip(request: Request) -> str:
    """Return the client IP used as the rate-limit key.

    The first non-empty entry of the ``X-Forwarded-For`` header wins, since a
    reverse proxy puts the originating client first. When the header is
    missing or contains only blanks, the socket peer address is used instead.

    Args:
        request: The incoming request.

    Returns:
        The client IP as a string, or ``"unknown"`` when no peer information
        is available.
    """
    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
    # overwrites it, so it can be spoofed to dodge the rate limiter. Confirm
    # the deployment always sits behind such a proxy.
    forwarded = request.headers.get("x-forwarded-for", "")
    for candidate in forwarded.split(","):
        candidate = candidate.strip()
        # Skip blank entries so that a malformed header such as " , 1.2.3.4"
        # does not collapse every such caller into one "" bucket.
        if candidate:
            return candidate
    return request.client.host if request.client else "unknown"
72
+
73
def _check_rate_limit(client_ip: str) -> bool:
    """Record a request from *client_ip* and say whether it is allowed.

    Timestamps older than ``RATE_LIMIT_WINDOW`` seconds are discarded first.
    If ``RATE_LIMIT_REQUESTS`` or more remain, the request is refused and not
    recorded; otherwise the current time is appended and the request passes.

    Returns:
        ``True`` when the request is accepted, ``False`` when it is throttled.
    """
    now = time.time()

    # Keep only the timestamps that are still inside the sliding window.
    recent = [
        stamp
        for stamp in _rate_limits.get(client_ip, [])
        if now - stamp < RATE_LIMIT_WINDOW
    ]
    _rate_limits[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        return False

    recent.append(now)

    # Once more than 1000 clients are tracked, drop those with no recorded
    # request in the last two windows so the table cannot grow without bound.
    if len(_rate_limits) > 1000:
        horizon = RATE_LIMIT_WINDOW * 2
        idle_ips = [
            ip
            for ip, stamps in _rate_limits.items()
            if not stamps or now - stamps[-1] > horizon
        ]
        for ip in idle_ips:
            del _rate_limits[ip]

    return True
96
+
97
+ # ─── Supabase helper ───
98
async def supabase_insert(table: str, data: dict):
    """Best-effort insert of one row into a Supabase table.

    Does nothing when Supabase is not configured. Never raises: persistence
    is optional for callers, so failures are logged rather than propagated.

    Args:
        table: Name of the table under ``/rest/v1/``.
        data: JSON-serializable row payload.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{SUPABASE_URL}/rest/v1/{table}",
                json=data,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    "Content-Type": "application/json",
                    "Prefer": "return=minimal",
                },
                timeout=10.0,
            )
        # A rejected insert (bad schema, auth, ...) used to vanish silently.
        if resp.status_code >= 400:
            print(f"[API] Supabase insert into '{table}' failed: HTTP {resp.status_code}")
    except Exception as e:
        # Still best-effort, but leave a trace for operators.
        print(f"[API] Supabase insert into '{table}' error: {e}")
116
+
117
async def supabase_query(table: str, params: dict, headers_extra: Optional[dict] = None):
    """Query a Supabase table and return the decoded JSON rows.

    Args:
        table: Name of the table under ``/rest/v1/``.
        params: Query-string parameters sent with the GET request.
        headers_extra: Optional extra request headers merged over the
            authentication headers.

    Returns:
        The decoded JSON body on HTTP 200; an empty list when Supabase is not
        configured, the response status is not 200, or the request fails.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return []
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{SUPABASE_URL}/rest/v1/{table}",
                params=params,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    # None default replaces the shared mutable `{}` default.
                    **(headers_extra or {}),
                },
                timeout=10.0,
            )
        if resp.status_code == 200:
            return resp.json()
        print(f"[API] Supabase query on '{table}' failed: HTTP {resp.status_code}")
        return []
    except Exception as e:
        # Callers treat "no data" and "query failed" alike; log the difference.
        print(f"[API] Supabase query on '{table}' error: {e}")
        return []
135
+
136
+ # ─── FIX v4.1: RAG sessions with TTL-based expiry ───
137
# In-process RAG session store: session id -> session payload. Each payload
# gets a "created_at" timestamp when it is stored.
_rag_sessions: dict[str, dict] = {}
# Maximum number of sessions kept; the oldest is evicted when this is reached.
_RAG_SESSION_MAX = 100
_RAG_SESSION_TTL = 3600  # 1 hour
140
+
141
def _cleanup_rag_sessions():
    """Drop every RAG session older than ``_RAG_SESSION_TTL`` seconds.

    A session without a ``created_at`` stamp is treated as created at time 0.
    """
    now = time.time()
    # Iterate over a snapshot because entries are deleted during the scan.
    for sid, session in list(_rag_sessions.items()):
        age = now - session.get("created_at", 0)
        if age > _RAG_SESSION_TTL:
            del _rag_sessions[sid]
147
+
148
def _store_rag_session(session_id: str, data: dict):
    """Save *data* under *session_id*, stamping it with the current time.

    Expired sessions are purged first. If the store is still at
    ``_RAG_SESSION_MAX`` entries, the entry with the smallest ``created_at``
    is evicted. The ``created_at`` key is written into *data* itself.
    """
    _cleanup_rag_sessions()

    if len(_rag_sessions) >= _RAG_SESSION_MAX:
        # Evict the session with the earliest creation stamp.
        oldest_id, _ = min(
            _rag_sessions.items(),
            key=lambda item: item[1].get("created_at", 0),
        )
        del _rag_sessions[oldest_id]

    data["created_at"] = time.time()
    _rag_sessions[session_id] = data
157
+
158
+ # ─── Request/Response Models ───
159
# Request body for POST /api/analyze.
class AnalyzeRequest(BaseModel):
    # Full contract text; when provided it must be at least 50 characters.
    text: Optional[str] = Field(None, min_length=50)
    # Alternative input: the endpoint joins these with blank lines when `text`
    # is absent or empty.
    clauses: Optional[list] = None
    # Stored alongside the analysis for authenticated users.
    source_url: Optional[str] = None
163
+
164
# Request body for POST /api/compare: the two contract texts to compare.
class CompareRequest(BaseModel):
    # Each text must be at least 50 characters; the upper bound
    # (MAX_TEXT_LENGTH) is enforced in the endpoint, not here.
    text_a: str = Field(..., min_length=50)
    text_b: str = Field(..., min_length=50)
167
+
168
# Request body for POST /api/explain.
class ExplainRequest(BaseModel):
    # Clause text to explain (10 to 2000 characters).
    clause: str = Field(..., min_length=10, max_length=2000)
    # Category label; the endpoint looks it up in DESC_MAP for a default description.
    category: str
171
+
172
# Response body for POST /api/explain.
class ExplainResponse(BaseModel):
    # Echo of the request's clause and category.
    clause: str
    category: str
    # Explanation, legal basis and recommendation; the endpoint fills these
    # from LLM output when available, otherwise from built-in defaults.
    explanation: str
    legal_basis: str
    recommendation: str
178
+
179
# Request body for POST /api/chat and POST /api/chat/stream.
class ChatRequest(BaseModel):
    # User message (1 to 2000 characters).
    message: str = Field(..., min_length=1, max_length=2000)
    # Key into the in-memory RAG session store (returned by /api/analyze).
    session_id: str
    # Prior conversation turns; passed to chat_respond as-is ([] when omitted).
    history: Optional[list[dict]] = None
183
+
184
# Request body for POST /api/redline. Either `session_id` or `text` must be
# supplied; the endpoint prefers a known session over raw text.
class RedlineRequest(BaseModel):
    session_id: Optional[str] = None
    text: Optional[str] = None
    # Forwarded to generate_redlines(..., use_llm=...).
    use_llm: bool = True
188
+
189
+ # ─── App ───
190
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Application lifespan hook. Currently a placeholder: nothing runs before
    # or after the yield.
    yield
193
+
194
# The ASGI application; the version string is also reported by /health.
app = FastAPI(title="ClauseGuard API", version="4.1.0", lifespan=lifespan)

# Browser origins allowed to call the API: the deployed web app plus local
# development servers.
ALLOWED_ORIGINS = [
    "https://clauseguardweb.netlify.app",
    "http://localhost:3000",
    "http://localhost:3001",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # NOTE(review): this regex matches ANY chrome-extension origin, and
    # credentials are allowed below. Confirm that is intended rather than
    # pinning the project's own extension id.
    allow_origin_regex=r"^chrome-extension://.*$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
209
+
210
@app.get("/health")
async def health():
    # Liveness/status probe: reports which classifier backend and OCR engine
    # are available and how many RAG sessions are currently held in memory.
    if _SHARED_MODULES:
        model_status = "ml" if cuad_model else "regex"
        ocr_status = get_ocr_status()
    else:
        # Shared modules failed to import, so neither name above exists.
        model_status = "regex"
        ocr_status = "unavailable"

    payload = {
        "status": "ok",
        "model": model_status,
        "version": "4.1.0",
        "shared_modules": _SHARED_MODULES,
        "ocr": ocr_status,
        "features": ["analyze", "compare", "redline", "chat", "ocr"],
        "rag_sessions_active": len(_rag_sessions),
    }
    return payload
223
+
224
@app.post("/api/analyze")
async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
    """Run the full analysis pipeline on a contract.

    Steps: rate-limit check, input validation, clause splitting and
    classification, entity / contradiction / obligation / compliance
    extraction, risk scoring, redline generation, RAG indexing into an
    in-memory session, and (for authenticated users) persistence to Supabase.

    Raises:
        HTTPException: 429 when rate limited; 503 when the shared analysis
            modules failed to import; 400 when the text is too short, too
            long, or yields no clauses.
    """
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded. Please wait 60 seconds.")

    # The pipeline functions come from an optional import; without them every
    # call below would raise NameError and surface as an opaque 500.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis modules are unavailable on this server.")

    text = req.text
    if not text and req.clauses:
        # `clauses` is an untyped list, so coerce items to str; joining raw
        # non-string items would raise TypeError.
        text = "\n\n".join(str(item) for item in req.clauses)

    if not text or len(text.strip()) < 50:
        raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")

    # FIX v4.1: Input size validation
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH // 1000}KB)")

    start = time.time()

    clauses = split_clauses(text)
    if not clauses:
        raise HTTPException(status_code=400, detail="No clauses detected")

    # One entry per (clause, prediction) pair; a clause with several
    # predictions appears several times.
    clause_results = []
    for clause in clauses:
        for pred in classify_cuad(clause) or ():
            clause_results.append({
                "text": clause,
                "label": pred["label"],
                "confidence": pred["confidence"],
                "risk": pred["risk"],
                "description": pred["description"],
                "source": pred.get("source", "unknown"),
            })

    entities = extract_entities(text)
    contradictions = detect_contradictions(clause_results, text)
    risk, grade, sev_counts = compute_risk_score(clause_results, len(clauses))
    obligations = extract_obligations(text)
    compliance = check_compliance(text)

    # v4.0: Redlining (best-effort: a failure must not fail the analysis)
    redlines = []
    try:
        redlines = generate_redlines({"clauses": clause_results}, use_llm=True)
    except Exception as e:
        print(f"[API] Redlining error: {e}")

    latency = int((time.time() - start) * 1000)

    results_for_db = [
        {
            "text": cr["text"],
            "categories": [{
                "name": cr["label"],
                "severity": cr["risk"],
                "confidence": cr["confidence"],
                "description": cr["description"],
            }],
        }
        for cr in clause_results
    ]

    # Number of distinct clause texts that received at least one prediction.
    flagged_count = len({cr["text"] for cr in clause_results})

    # RAG indexing with TTL-managed sessions (best-effort as well)
    session_id = None
    try:
        chunks, embeddings, _status = index_contract(text)
        if chunks and embeddings is not None:
            session_id = uuid.uuid4().hex[:12]
            _store_rag_session(session_id, {
                "chunks": chunks,
                "embeddings": embeddings,
                "analysis": {
                    "risk": {"score": risk, "grade": grade, "breakdown": sev_counts},
                    "metadata": {"total_clauses": len(clauses), "flagged_clauses": len(clause_results)},
                    "clauses": clause_results[:30],
                    "entities": entities[:30],
                    "contradictions": contradictions,
                },
            })
    except Exception as e:
        print(f"[API] RAG indexing error: {e}")

    if user:
        await supabase_insert("analyses", {
            "user_id": user["id"],
            "source_url": req.source_url,
            "total_clauses": len(clauses),
            "flagged_count": flagged_count,
            "risk_score": risk,
            "grade": grade,
            "clauses": results_for_db,
            "entities": entities,
            "contradictions": contradictions,
            "obligations": obligations,
            "compliance": compliance,
        })

    return {
        "risk_score": risk,
        "grade": grade,
        "total_clauses": len(clauses),
        "flagged_count": flagged_count,
        "results": results_for_db,
        "entities": entities,
        "contradictions": contradictions,
        "obligations": obligations,
        "compliance": compliance,
        "redlines": redlines,
        "model": "ml" if cuad_model else "regex",
        "latency_ms": latency,
        "session_id": session_id,
    }
339
+
340
@app.post("/api/compare")
async def compare(req: CompareRequest, request: Request):
    """Compare two contract texts and return ``compare_contracts``' result.

    Raises:
        HTTPException: 429 when rate limited; 503 when the shared analysis
            modules failed to import; 400 when either text exceeds
            ``MAX_TEXT_LENGTH``.
    """
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    # compare_contracts comes from an optional import; fail with a clear 503
    # instead of a NameError-driven 500 when it is missing.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis modules are unavailable on this server.")

    # FIX v4.1: Input size validation for comparison
    if len(req.text_a) > MAX_TEXT_LENGTH or len(req.text_b) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH // 1000}KB per contract)")

    return compare_contracts(req.text_a, req.text_b)
351
+
352
@app.post("/api/redline")
async def redline(req: RedlineRequest, request: Request):
    """Generate redline suggestions for a stored session or for raw text.

    A known, unexpired ``session_id`` takes precedence; otherwise ``text`` is
    analyzed from scratch with ``analyze_contract``.

    Raises:
        HTTPException: 429 when rate limited; 503 when the shared analysis
            modules failed to import; 400 when the text is too long, the
            analysis reports an error, or neither input is usable.
    """
    client_ip = _get_client_ip(request)
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    # analyze_contract / generate_redlines come from an optional import.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis modules are unavailable on this server.")

    # Purge expired sessions first, as the chat endpoints do, so that a
    # session past its TTL is not silently reused here.
    _cleanup_rag_sessions()

    session = _rag_sessions.get(req.session_id) if req.session_id else None
    if session is not None:
        analysis = session["analysis"]
    elif req.text:
        if len(req.text) > MAX_TEXT_LENGTH:
            raise HTTPException(status_code=400, detail="Text too long")
        result, error = analyze_contract(req.text)
        if error:
            raise HTTPException(status_code=400, detail=error)
        analysis = result
    else:
        raise HTTPException(status_code=400, detail="Provide session_id or text")

    redlines = generate_redlines(analysis, use_llm=req.use_llm)
    return {"redlines": redlines, "count": len(redlines)}
372
+
373
@app.post("/api/chat")
async def chat(req: ChatRequest, request: Request):
    # Answer one chat message against a previously indexed contract session
    # and return only the final response text.
    if not _check_rate_limit(_get_client_ip(request)):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    # FIX v4.1: Clean up expired sessions before checking
    _cleanup_rag_sessions()

    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session expired or not found. Please analyze a contract first.")
    session = _rag_sessions[req.session_id]

    # Drain the generator; the last value it yields is the response returned
    # (it stays "" if nothing is yielded).
    response_text = ""
    for response_text in chat_respond(
        req.message,
        req.history or [],
        session["chunks"],
        session["embeddings"],
        session["analysis"],
    ):
        pass

    return {"response": response_text, "session_id": req.session_id}
392
+
393
@app.post("/api/chat/stream")
async def chat_stream(req: ChatRequest, request: Request):
    # Streaming variant of /api/chat: emits server-sent events carrying the
    # new text of each step, then a terminating [DONE] marker.
    if not _check_rate_limit(_get_client_ip(request)):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    _cleanup_rag_sessions()

    if req.session_id not in _rag_sessions:
        raise HTTPException(status_code=404, detail="Session expired or not found.")
    session = _rag_sessions[req.session_id]

    async def generate():
        # Each yielded value is sliced from the length of the previous one, so
        # only the suffix beyond what was already sent is emitted.
        # NOTE(review): this presumes chat_respond yields cumulative text.
        sent = 0
        snapshots = chat_respond(
            req.message,
            req.history or [],
            session["chunks"],
            session["embeddings"],
            session["analysis"],
        )
        for snapshot in snapshots:
            delta = snapshot[sent:]
            sent = len(snapshot)
            if not delta:
                continue
            yield f"data: {json.dumps({'delta': delta})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
419
+
420
@app.post("/api/ocr")
async def ocr_endpoint(file: UploadFile = FastAPIFile(...)):
    """Extract text from an uploaded PDF via ``parse_pdf_smart``.

    Returns:
        A dict with the extracted ``text``, the extraction ``method``, the
        character count ``chars`` and the original ``filename``.

    Raises:
        HTTPException: 400 when the upload is not a ``.pdf``, exceeds 20MB, or
            the parser reports an error; 503 when the shared modules failed
            to import.
    """
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files supported")

    # parse_pdf_smart comes from an optional import.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="OCR modules are unavailable on this server.")

    # FIX v4.1: Limit upload size (20MB). Read at most one byte past the limit
    # so an oversized upload is rejected without buffering all of it.
    max_bytes = 20 * 1024 * 1024
    content = await file.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise HTTPException(status_code=400, detail="File too large (max 20MB)")

    tmp_path = None
    try:
        # delete=False so the parser can reopen the file by path; removal is
        # handled in `finally`, including when the write itself fails.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(content)

        text, error, method = parse_pdf_smart(tmp_path)
        if error:
            raise HTTPException(status_code=400, detail=error)
        return {"text": text, "method": method, "chars": len(text) if text else 0, "filename": file.filename}
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
441
+
442
@app.post("/api/explain", response_model=ExplainResponse)
async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
    """Explain why a clause may be risky.

    Starts from built-in defaults (the ``DESC_MAP`` description for the
    category plus generic legal-basis and recommendation strings). When both
    ``SAULLM_ENDPOINT`` and ``HF_API_TOKEN`` are set, the LLM is queried and
    up to three blank-line-separated paragraphs of its answer replace the
    defaults. LLM failures fall back to the defaults.
    """
    fallback_desc = "Unknown category."
    # DESC_MAP comes from an optional import; degrade to the generic text
    # rather than raising NameError when it is missing.
    desc = DESC_MAP.get(req.category, fallback_desc) if _SHARED_MODULES else fallback_desc
    legal = "Consult local consumer protection laws."
    recommendation = "Review this clause carefully."

    if SAULLM_ENDPOINT and HF_API_TOKEN:
        prompt = (
            f"Analyze this contract clause and explain why it may be risky.\n\n"
            f"Clause: \"{req.clause}\"\nCategory: {req.category}\n\n"
            f"Provide: 1) Plain-English explanation 2) Legal basis 3) Recommendation"
        )
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.post(
                    SAULLM_ENDPOINT,
                    json={"inputs": prompt, "parameters": {"max_new_tokens": 300, "temperature": 0.3}},
                    headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
                )
            if resp.status_code == 200:
                output = resp.json()
                # The payload is either a list of generations or a single
                # object; tolerate an empty list instead of raising IndexError.
                if isinstance(output, list):
                    first = output[0] if output else {}
                else:
                    first = output
                generated = first.get("generated_text", "") if isinstance(first, dict) else ""
                # Some text-generation backends echo the prompt before the
                # completion; drop it so the prompt's own paragraphs are not
                # mistaken for the explanation.
                generated = (generated or "").removeprefix(prompt).strip()
                if len(generated) > 50:
                    parts = [p.strip() for p in generated.split("\n\n") if p.strip()]
                    desc = parts[0] if len(parts) > 0 else desc
                    legal = parts[1] if len(parts) > 1 else legal
                    recommendation = parts[2] if len(parts) > 2 else recommendation
        except Exception as e:
            # Best-effort enrichment: keep the defaults, but leave a trace.
            print(f"[API] Explain LLM error: {e}")

    return ExplainResponse(clause=req.clause, category=req.category,
                           explanation=desc, legal_basis=legal, recommendation=recommendation)
474
+
475
@app.get("/api/history")
async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
    """Return the authenticated user's stored analyses, newest first.

    Args:
        user: Authenticated user record; its ``id`` filters the rows.
        limit: Page size, clamped to the range 1..100.
        offset: Number of rows to skip; negative values are treated as 0.

    Returns:
        A dict with the ``analyses`` rows and the effective ``limit`` and
        ``offset`` that were applied.
    """
    # Clamp both ends: negative values would otherwise be forwarded to the
    # backend and come back as an error that is reported as an empty page.
    limit = max(1, min(limit, 100))
    offset = max(offset, 0)
    data = await supabase_query("analyses", {
        "user_id": f"eq.{user['id']}", "select": "*",
        "order": "created_at.desc", "limit": str(limit), "offset": str(offset),
    })
    return {"analyses": data, "limit": limit, "offset": offset}
483
+
484
# Development entry point: serve the app with uvicorn when run as a script.
if __name__ == "__main__":
    import uvicorn
    # Binds to all interfaces on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)