gaurv007 committed on
Commit
4939516
·
verified ·
1 Parent(s): e696558

v3.0: Upload actual api/main.py content

Browse files
Files changed (1) hide show
  1. api/main.py +339 -1
api/main.py CHANGED
@@ -1 +1,339 @@
1
- /app/clauseguard/api/main.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ClauseGuard — FastAPI Backend v3.0
3
+ ══════════════════════════════════
4
+ FIXED in v3.0:
5
+ • Imports shared modules (no code duplication)
6
+ • Fixed API schema to accept both {text} and {clauses} from extension
7
+ • Added rate limiting
8
+ • Added max text length validation
9
+ • Fixed CORS (removed wildcard)
10
+ • Added proper error responses
11
+ """
12
+
13
import json
import logging
import os
import re
import time
from collections import defaultdict
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional
21
+
22
+ import httpx
23
+ import numpy as np
24
+ from fastapi import FastAPI, HTTPException, Depends, Body, Request
25
+ from fastapi.middleware.cors import CORSMiddleware
26
+ from pydantic import BaseModel, Field
27
+
28
+ from auth import get_current_user, require_auth
29
+
30
+ # ── Import shared modules ──
31
+ # When deployed, these must be in the same directory or on PYTHONPATH
32
+ import sys
33
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
34
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
35
+
36
# Import the shared analysis modules defensively so the API process can still
# start (and report its degraded state via /health) when they are missing.
try:
    from app import (
        split_clauses, classify_cuad, extract_entities,
        detect_contradictions, compute_risk_score,
        CUAD_LABELS, RISK_MAP, DESC_MAP, _model_status,
        cuad_model, cuad_tokenizer
    )
    from obligations import extract_obligations
    from compliance import check_compliance
    from compare import compare_contracts
    _SHARED_MODULES = True
except ImportError as _import_error:
    _SHARED_MODULES = False
    # Report the real cause; there are no inline fallbacks in this file.
    print(f"[API] WARNING: Could not import shared modules ({_import_error}); analysis is unavailable")

    def _shared_module_unavailable(*_args, **_kwargs):
        """Stand-in bound to every shared function when the import failed."""
        raise RuntimeError("ClauseGuard shared analysis modules are not available")

    # Bind every name the rest of the module references so a failed import
    # surfaces as a clear RuntimeError instead of a NameError.
    split_clauses = classify_cuad = extract_entities = _shared_module_unavailable
    detect_contradictions = compute_risk_score = _shared_module_unavailable
    extract_obligations = check_compliance = _shared_module_unavailable
    compare_contracts = _shared_module_unavailable
    # DESC_MAP is read with .get() further down, so a dict placeholder is safe.
    # NOTE(review): CUAD_LABELS / RISK_MAP container types are assumed — confirm against app.py.
    CUAD_LABELS = []
    RISK_MAP = {}
    DESC_MAP = {}
    _model_status = None
    cuad_model = None
    cuad_tokenizer = None
50
+
51
# ─── Config ───
# Every setting is read from the process environment once, at import time.
# Unset string settings fall back to "".
SUPABASE_URL = os.getenv("SUPABASE_URL", "")
SUPABASE_SERVICE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "")
HF_API_TOKEN = os.getenv("HF_API_TOKEN", "")
SAULLM_ENDPOINT = os.getenv("SAULLM_ENDPOINT", "")
# Upper bound compared against len(text), i.e. characters; 100000 by default.
MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "100000"))
57
+
58
# ─── Rate Limiting ───
# In-memory fixed-window counter keyed by client IP. The state lives in this
# module, so it is per-process and is lost on restart.
_rate_limits = {}  # ip -> (count, window_start)
RATE_LIMIT_REQUESTS = 30
RATE_LIMIT_WINDOW = 60  # seconds
# Once more than this many IPs are tracked, expired windows are swept so the
# table cannot grow without bound.
_RATE_LIMIT_MAX_TRACKED = 10_000


def _check_rate_limit(client_ip: str) -> bool:
    """Record one request for *client_ip* and say whether it is allowed.

    Args:
        client_ip: Key identifying the caller.

    Returns:
        True if the request fits within RATE_LIMIT_REQUESTS for the current
        RATE_LIMIT_WINDOW-second window, False if the limit is already reached.
        Rejected requests do not extend or reset the window.
    """
    # Monotonic clock: wall-clock adjustments must not stretch or cut a window.
    now = time.monotonic()

    if len(_rate_limits) > _RATE_LIMIT_MAX_TRACKED:
        expired = [
            ip for ip, (_, started) in _rate_limits.items()
            if now - started > RATE_LIMIT_WINDOW
        ]
        for ip in expired:
            del _rate_limits[ip]

    count, window_start = _rate_limits.get(client_ip, (0, now))
    if now - window_start > RATE_LIMIT_WINDOW:
        # Previous window is over; start a fresh one.
        count, window_start = 0, now
    if count >= RATE_LIMIT_REQUESTS:
        return False
    _rate_limits[client_ip] = (count + 1, window_start)
    return True
76
+
77
+ # ─── Supabase helper ───
78
async def supabase_insert(table: str, data: dict):
    """Best-effort insert of one row into a Supabase table over its REST API.

    Does nothing when SUPABASE_URL or SUPABASE_SERVICE_KEY is empty. Failures
    never propagate to the caller; they are logged as warnings instead of
    being silently discarded.

    Args:
        table: Table name appended to ``/rest/v1/``.
        data: JSON-serialisable row payload.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return

    log = logging.getLogger(__name__)
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                f"{SUPABASE_URL}/rest/v1/{table}",
                json=data,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    "Content-Type": "application/json",
                    "Prefer": "return=minimal",
                },
                timeout=10.0,
            )
    except Exception:
        # Persistence is optional for the caller, so keep going but leave a trace.
        log.warning("Supabase insert into %r failed", table, exc_info=True)
        return

    if resp.status_code >= 400:
        log.warning("Supabase insert into %r rejected with HTTP %s", table, resp.status_code)
96
+
97
async def supabase_query(table: str, params: dict, headers_extra: Optional[dict] = None):
    """Run a GET against a Supabase table and return the decoded JSON body.

    Args:
        table: Table name appended to ``/rest/v1/``.
        params: Query-string parameters forwarded as-is.
        headers_extra: Optional extra request headers; they override the
            default auth headers on key collision.

    Returns:
        The parsed JSON body when the response status is 200; otherwise an
        empty list (also when Supabase is not configured or the request fails).
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return []

    log = logging.getLogger(__name__)
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{SUPABASE_URL}/rest/v1/{table}",
                params=params,
                headers={
                    "apikey": SUPABASE_SERVICE_KEY,
                    "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                    # A None default replaces the shared mutable `{}` default.
                    **(headers_extra or {}),
                },
                timeout=10.0,
            )
        if resp.status_code != 200:
            log.warning("Supabase query on %r returned HTTP %s", table, resp.status_code)
            return []
        return resp.json()
    except Exception:
        log.warning("Supabase query on %r failed", table, exc_info=True)
        return []
115
+
116
+ # ─── Request/Response Models ───
117
class AnalyzeRequest(BaseModel):
    # Body of POST /api/analyze: the caller sends either `text` or `clauses`.
    # Full document text; when present it must be at least 50 characters.
    text: Optional[str] = Field(None, min_length=50)
    # Pre-split clauses sent by the extension. Typed as strings so that a
    # malformed payload is rejected by validation instead of failing later
    # when the clauses are joined into one text.
    clauses: Optional[list[str]] = None
    # Stored with the analysis when the request is authenticated.
    source_url: Optional[str] = None
121
+
122
class AnalyzeResponse(BaseModel):
    # Response body of POST /api/analyze.

    # Overall score and grade as returned by compute_risk_score.
    risk_score: int
    grade: str

    # Number of clauses found, and number of distinct clause texts that
    # received at least one prediction.
    total_clauses: int
    flagged_count: int

    # One entry per (clause, prediction) pair: {"text", "categories": [...]}.
    results: list[dict]

    # Outputs of the shared extraction helpers, passed through unchanged.
    entities: list[dict]
    contradictions: list[dict]
    obligations: list[dict]
    compliance: dict

    # "ml" or "regex", plus the wall-clock analysis time in milliseconds.
    model: str
    latency_ms: int
134
+
135
class CompareRequest(BaseModel):
    # Body of POST /api/compare: two contract texts, each required and at
    # least 50 characters long.
    text_a: str = Field(default=..., min_length=50)
    text_b: str = Field(default=..., min_length=50)
138
+
139
class ExplainRequest(BaseModel):
    # Body of POST /api/explain.
    # The clause to explain: required, 10 to 2000 characters.
    clause: str = Field(default=..., min_length=10, max_length=2000)
    # Category name; used as the lookup key into DESC_MAP by the endpoint.
    category: str
142
+
143
class ExplainResponse(BaseModel):
    # Response body of POST /api/explain.

    # Echo of the request fields.
    clause: str
    category: str

    # Three explanation sections: plain-English meaning, legal concern, and
    # a practical recommendation.
    explanation: str
    legal_basis: str
    recommendation: str
149
+
150
+ # ─── App ───
151
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook passed to FastAPI; performs no startup or shutdown work.

    Nothing is loaded here because, per the original note, the models are
    loaded as a side effect of importing ``app.py``.
    """
    yield
155
+
156
app = FastAPI(title="ClauseGuard API", version="3.0.0", lifespan=lifespan)

# CORS: explicit allow-list of web origins instead of a "*" wildcard.
ALLOWED_ORIGINS = [
    "https://clauseguardweb.netlify.app",
    "http://localhost:3000",
    "http://localhost:3001",
]

# Browser extensions are matched by regex because their origin is not fixed.
# NOTE(review): the pattern accepts ANY chrome-extension:// origin while
# credentials are allowed — confirm this is intended or pin the extension ID.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_origin_regex=r"^chrome-extension://.*$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
173
+
174
@app.get("/health")
async def health():
    """Report service status and which classification backend is active.

    ``model`` is "ml" only when the shared modules were imported and
    ``cuad_model`` is truthy; otherwise it is "regex".
    """
    if _SHARED_MODULES and cuad_model:
        backend = "ml"
    else:
        backend = "regex"
    return dict(
        status="ok",
        model=backend,
        version="3.0.0",
        shared_modules=_SHARED_MODULES,
    )
183
+
184
@app.post("/api/analyze", response_model=AnalyzeResponse)
async def analyze(req: AnalyzeRequest, request: Request, user: Optional[dict] = Depends(get_current_user)):
    """Analyze a contract and return per-clause findings plus an overall score.

    The document comes from ``req.text`` or, when that is empty, from
    ``req.clauses`` joined with blank lines. When ``user`` is truthy the
    result is also stored in the Supabase ``analyses`` table (best effort).

    Raises:
        HTTPException 429: the caller's IP exceeded the rate limit.
        HTTPException 503: the shared analysis modules failed to import.
        HTTPException 400: clauses are not strings, the text is shorter than
            50 characters or longer than MAX_TEXT_LENGTH, or no clauses were
            detected.
    """
    # Rate limiting
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again in 60 seconds.")

    # Without the shared modules the helpers below are unusable; answer with
    # a clear 503 instead of crashing with an internal error.
    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis engine is unavailable")

    # Accept either text or clauses from the extension
    text = req.text
    if not text and req.clauses:
        if isinstance(req.clauses, list):
            # str.join raises TypeError on non-string items; reject them up front.
            if not all(isinstance(item, str) for item in req.clauses):
                raise HTTPException(status_code=400, detail="clauses must be a list of strings")
            text = "\n\n".join(req.clauses)
        else:
            text = str(req.clauses)

    if not text or len(text.strip()) < 50:
        raise HTTPException(status_code=400, detail="Text too short (minimum 50 characters)")

    # Max length check
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (maximum {MAX_TEXT_LENGTH} characters)")

    start = time.time()
    clauses = split_clauses(text)
    if not clauses:
        raise HTTPException(status_code=400, detail="No clauses detected in document")

    # One flat record per (clause, prediction) pair; clauses without any
    # prediction contribute nothing.
    clause_results = []
    for clause in clauses:
        for pred in classify_cuad(clause) or ():
            clause_results.append({
                "text": clause,
                "label": pred["label"],
                "confidence": pred["confidence"],
                "risk": pred["risk"],
                "description": pred["description"],
                "source": pred.get("source", "unknown"),
            })

    entities = extract_entities(text)
    contradictions = detect_contradictions(clause_results, text)
    # The third returned value (severity counts) is not used by this endpoint.
    risk, grade, _sev_counts = compute_risk_score(clause_results, len(clauses))
    obligations = extract_obligations(text)
    compliance = check_compliance(text)
    # Latency covers analysis only; the optional database write is excluded.
    latency = int((time.time() - start) * 1000)

    results_for_db = [
        {
            "text": cr["text"],
            "categories": [{
                "name": cr["label"],
                "severity": cr["risk"],
                "confidence": cr["confidence"],
                "description": cr["description"],
            }],
        }
        for cr in clause_results
    ]

    # Distinct clause texts with at least one prediction; computed once and
    # shared by the stored row and the response.
    flagged_count = len({cr["text"] for cr in clause_results})

    if user:
        await supabase_insert("analyses", {
            "user_id": user["id"],
            "source_url": req.source_url,
            "total_clauses": len(clauses),
            "flagged_count": flagged_count,
            "risk_score": risk,
            "grade": grade,
            "clauses": results_for_db,
            "entities": entities,
            "contradictions": contradictions,
            "obligations": obligations,
            "compliance": compliance,
        })

    return AnalyzeResponse(
        risk_score=risk,
        grade=grade,
        total_clauses=len(clauses),
        flagged_count=flagged_count,
        results=results_for_db,
        entities=entities,
        contradictions=contradictions,
        obligations=obligations,
        compliance=compliance,
        model="ml" if cuad_model else "regex",
        latency_ms=latency,
    )
269
+
270
@app.post("/api/compare")
async def compare(req: CompareRequest, request: Request):
    """Compare two contract texts and return ``compare_contracts``' result.

    Raises:
        HTTPException 429: the caller's IP exceeded the rate limit.
        HTTPException 503: the shared analysis modules failed to import.
        HTTPException 400: either text is longer than MAX_TEXT_LENGTH.
    """
    client_ip = request.client.host if request.client else "unknown"
    if not _check_rate_limit(client_ip):
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")

    if not _SHARED_MODULES:
        raise HTTPException(status_code=503, detail="Analysis engine is unavailable")

    # Apply the same size cap as /api/analyze so this endpoint cannot be used
    # to submit unbounded documents.
    if max(len(req.text_a), len(req.text_b)) > MAX_TEXT_LENGTH:
        raise HTTPException(status_code=400, detail=f"Text too long (maximum {MAX_TEXT_LENGTH} characters)")

    return compare_contracts(req.text_a, req.text_b)
277
+
278
@app.post("/api/explain", response_model=ExplainResponse)
async def explain(req: ExplainRequest, user: dict = Depends(require_auth)):
    """Explain why a clause may be risky; requires an authenticated user.

    Starts from static defaults (the DESC_MAP entry for the category plus
    generic legal-basis and recommendation text). When SAULLM_ENDPOINT and
    HF_API_TOKEN are both set, the text-generation endpoint is asked for a
    three-part answer that replaces the defaults section by section. Any
    failure of that call is logged and the defaults are returned.
    """
    # DESC_MAP exists only when the shared modules were imported.
    desc = DESC_MAP.get(req.category, "Unknown category.") if _SHARED_MODULES else "Unknown category."
    legal = "Consult local consumer protection laws."
    recommendation = "Review this clause carefully. Consider negotiating or seeking legal advice before agreeing."

    if SAULLM_ENDPOINT and HF_API_TOKEN:
        try:
            prompt = (
                f"You are a consumer protection legal analyst. Analyze this contract clause "
                f"and explain why it may be unfair or risky.\n\n"
                f"Clause: \"{req.clause}\"\n"
                f"Category: {req.category}\n\n"
                f"Provide:\n"
                f"1. A plain-English explanation of what this clause means\n"
                f"2. The specific legal basis or consumer protection concern\n"
                f"3. A practical recommendation\n\n"
                f"Be concise. 3-4 sentences per section."
            )
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.post(
                    SAULLM_ENDPOINT,
                    json={
                        "inputs": prompt,
                        "parameters": {
                            "max_new_tokens": 300,
                            "temperature": 0.3,
                            # Ask for the completion only; otherwise the
                            # prompt can be echoed back ahead of the answer.
                            # NOTE(review): confirm the endpoint honours this.
                            "return_full_text": False,
                        },
                    },
                    headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
                )
            if resp.status_code == 200:
                output = resp.json()
                if isinstance(output, list):
                    generated = output[0].get("generated_text", "") if output else ""
                else:
                    generated = output.get("generated_text", "")
                generated = generated or ""
                # Strip an echoed prompt so the first paragraph is never the
                # instructions themselves.
                if generated.startswith(prompt):
                    generated = generated[len(prompt):]
                generated = generated.strip()
                if len(generated) > 50:
                    parts = [part.strip() for part in generated.split("\n\n") if part.strip()]
                    desc = parts[0] if len(parts) > 0 else desc
                    legal = parts[1] if len(parts) > 1 else legal
                    recommendation = parts[2] if len(parts) > 2 else recommendation
        except Exception:
            # The LLM is an optional enhancement: keep the static defaults,
            # but record the failure instead of discarding it.
            logging.getLogger(__name__).warning("Explanation request to SAULLM_ENDPOINT failed", exc_info=True)

    return ExplainResponse(
        clause=req.clause,
        category=req.category,
        explanation=desc,
        legal_basis=legal,
        recommendation=recommendation,
    )
321
+
322
@app.get("/api/history")
async def history(user: dict = Depends(require_auth), limit: int = 20, offset: int = 0):
    """Return the authenticated user's stored analyses, newest first.

    Args:
        user: Authenticated user; its ``id`` filters the rows.
        limit: Page size, clamped to the range 1..100.
        offset: Number of rows to skip; negative values are treated as 0.

    Returns:
        ``{"analyses": [...], "limit": <effective limit>, "offset": <effective offset>}``.
    """
    # Clamp both ends: a negative or zero value would otherwise be passed
    # straight into the query.
    limit = max(1, min(limit, 100))
    offset = max(0, offset)
    data = await supabase_query(
        "analyses",
        {
            "user_id": f"eq.{user['id']}",
            "select": "*",
            "order": "created_at.desc",
            "limit": str(limit),
            "offset": str(offset),
        },
    )
    return {"analyses": data, "limit": limit, "offset": offset}
336
+
337
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on every interface, port 8000.
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=8000)